mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Merge pull request #2157 from kohya-ss/feature-chroma-support
Feature Chroma support
This commit is contained in:
13
README-ja.md
13
README-ja.md
@@ -155,11 +155,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--n` ネガティブプロンプト(次のオプションまで)
|
||||
* `--w` 生成画像の幅を指定
|
||||
* `--h` 生成画像の高さを指定
|
||||
* `--d` 生成画像のシード値を指定
|
||||
* `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください
|
||||
* `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください
|
||||
* `--s` 生成時のステップ数を指定
|
||||
|
||||
`( )` や `[ ]` などの重みづけも動作します。
|
||||
|
||||
12
README.md
12
README.md
@@ -16,6 +16,13 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Jul 30, 2025:
|
||||
- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details.
|
||||
|
||||
- Support for [Chroma](https://huggingface.co/lodestones/Chroma) has been added in PR [#2157](https://github.com/kohya-ss/sd-scripts/pull/2157). Thank you to lodestones for the high-quality model.
|
||||
- Chroma is a new model based on FLUX.1 schnell. In this repository, `flux_train_network.py` is used for training LoRAs for Chroma with `--model_type chroma`.
|
||||
- Please refer to the [FLUX.1 LoRA training documentation](./docs/flux_train_network.md) for more details.
|
||||
|
||||
Jul 21, 2025:
|
||||
- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions.
|
||||
- Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details.
|
||||
@@ -1367,9 +1374,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
|
||||
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
|
||||
* `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG.
|
||||
* `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
@@ -4,6 +4,13 @@ Status: reviewed
|
||||
|
||||
This document explains how to train LoRA models for the FLUX.1 model using `flux_train_network.py` included in the `sd-scripts` repository.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
このドキュメントでは、`sd-scripts`リポジトリに含まれる`flux_train_network.py`を使用して、FLUX.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
|
||||
|
||||
</details>
|
||||
|
||||
## 1. Introduction / はじめに
|
||||
|
||||
`flux_train_network.py` trains additional networks such as LoRA on the FLUX.1 model, which uses a transformer-based architecture different from Stable Diffusion. Two text encoders, CLIP-L and T5-XXL, and a dedicated AutoEncoder are used.
|
||||
@@ -15,21 +22,103 @@ This guide assumes you know the basics of LoRA training. For common options see
|
||||
* The repository is cloned and the Python environment is ready.
|
||||
* A training dataset is prepared. See the dataset configuration guide.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`flux_train_network.py`は、FLUX.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。FLUX.1はStable Diffusionとは異なるアーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
**前提条件:**
|
||||
|
||||
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
|
||||
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)
|
||||
|
||||
</details>
|
||||
|
||||
## 2. Differences from `train_network.py` / `train_network.py` との違い
|
||||
|
||||
`flux_train_network.py` is based on `train_network.py` but adapted for FLUX.1. Main differences include required arguments for the FLUX.1 model, CLIP-L, T5-XXL and AE, different model structure, and some incompatible options from Stable Diffusion.
|
||||
`flux_train_network.py` is based on `train_network.py` but adapted for FLUX.1. Main differences include:
|
||||
|
||||
* **Target model:** FLUX.1 model (dev or schnell version).
|
||||
* **Model structure:** Unlike Stable Diffusion, FLUX.1 uses a Transformer-based architecture with two text encoders (CLIP-L and T5-XXL) and a dedicated AutoEncoder (AE) instead of VAE.
|
||||
* **Required arguments:** Additional arguments for FLUX.1 model, CLIP-L, T5-XXL, and AE model files.
|
||||
* **Incompatible options:** Some Stable Diffusion-specific arguments (e.g., `--v2`, `--clip_skip`, `--max_token_length`) are not used in FLUX.1 training.
|
||||
* **FLUX.1-specific arguments:** Additional arguments for FLUX.1-specific training parameters like timestep sampling and guidance scale.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`flux_train_network.py`は`train_network.py`をベースに、FLUX.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** FLUX.1モデル(dev版またはschnell版)を対象とします。
|
||||
* **モデル構造:** Stable Diffusionとは異なり、FLUX.1はTransformerベースのアーキテクチャを持ちます。Text EncoderとしてCLIP-LとT5-XXLの二つを使用し、VAEの代わりに専用のAutoEncoder (AE) を使用します。
|
||||
* **必須の引数:** FLUX.1モデル、CLIP-L、T5-XXL、AEの各モデルファイルを指定する引数が追加されています。
|
||||
* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)はFLUX.1の学習では使用されません。
|
||||
* **FLUX.1特有の引数:** タイムステップのサンプリング方法やガイダンススケールなど、FLUX.1特有の学習パラメータを指定する引数が追加されています。
|
||||
|
||||
</details>
|
||||
|
||||
## 3. Preparation / 準備
|
||||
|
||||
Before starting training you need:
|
||||
|
||||
1. **Training script:** `flux_train_network.py`
|
||||
2. **FLUX.1 model file** and text encoder files (`clip_l`, `t5xxl`) and AE file.
|
||||
3. **Dataset definition file (.toml)** such as `my_flux_dataset_config.toml`.
|
||||
2. **FLUX.1 model file:** Base FLUX.1 model `.safetensors` file (e.g., `flux1-dev.safetensors`).
|
||||
3. **Text Encoder model files:**
|
||||
- CLIP-L model `.safetensors` file (e.g., `clip_l.safetensors`)
|
||||
- T5-XXL model `.safetensors` file (e.g., `t5xxl.safetensors`)
|
||||
4. **AutoEncoder model file:** FLUX.1-compatible AE model `.safetensors` file (e.g., `ae.safetensors`).
|
||||
5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration (e.g., `my_flux_dataset_config.toml`).
|
||||
|
||||
### Downloading Required Models
|
||||
|
||||
To train FLUX.1 models, you need to download the following model files:
|
||||
|
||||
- **DiT, AE**: Download from the [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) repository. Use `flux1-dev.safetensors` and `ae.safetensors`. The weights in the subfolder are in Diffusers format and cannot be used.
|
||||
- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: Download from the [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) repository. Please use `t5xxl_fp16.safetensors` for T5-XXL. Thanks to ComfyUI for providing these models.
|
||||
|
||||
To train Chroma models, you need to download the Chroma model file from the following repository:
|
||||
|
||||
- **Chroma Base**: Download from the [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) repository. Use `Chroma.safetensors`.
|
||||
|
||||
We have tested Chroma training with the weights from the [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) repository.
|
||||
|
||||
AE and T5-XXL models are same as FLUX.1, so you can use the same files. CLIP-L model is not used for Chroma training, so you can omit the `--clip_l` argument.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習を開始する前に、以下のファイルが必要です。
|
||||
|
||||
1. **学習スクリプト:** `flux_train_network.py`
|
||||
2. **FLUX.1モデルファイル:** 学習のベースとなるFLUX.1モデルの`.safetensors`ファイル(例: `flux1-dev.safetensors`)。
|
||||
3. **Text Encoderモデルファイル:**
|
||||
- CLIP-Lモデルの`.safetensors`ファイル。例として`clip_l.safetensors`を使用します。
|
||||
- T5-XXLモデルの`.safetensors`ファイル。例として`t5xxl.safetensors`を使用します。
|
||||
4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。
|
||||
5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。例として`my_flux_dataset_config.toml`を使用します。
|
||||
|
||||
**必要なモデルのダウンロード**
|
||||
|
||||
FLUX.1モデルを学習するためには、以下のモデルファイルをダウンロードする必要があります。
|
||||
|
||||
- **DiT, AE**: [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) リポジトリからダウンロードします。`flux1-dev.safetensors`と`ae.safetensors`を使用してください。サブフォルダ内の重みはDiffusers形式であり、使用できません。
|
||||
- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) リポジトリからダウンロードします。T5-XXLには`t5xxl_fp16.safetensors`を使用してください。これらのモデルを提供いただいたComfyUIに感謝します。
|
||||
|
||||
Chromaモデルを学習する場合は、以下のリポジトリからChromaモデルファイルをダウンロードする必要があります。
|
||||
|
||||
- **Chroma Base**: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) リポジトリからダウンロードします。`Chroma.safetensors`を使用してください。
|
||||
|
||||
Chromaの学習のテストは [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) リポジトリの重みを使用して行いました。
|
||||
|
||||
AEとT5-XXLモデルはFLUX.1と同じものを使用できるため、同じファイルを使用します。CLIP-LモデルはChroma学習では使用されないため、`--clip_l`引数は省略できます。
|
||||
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
|
||||
Run `flux_train_network.py` from the terminal with FLUX.1 specific arguments. Example:
|
||||
Run `flux_train_network.py` from the terminal with FLUX.1 specific arguments. Here's a basic command example:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \
|
||||
@@ -54,369 +143,369 @@ accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \
|
||||
--gradient_checkpointing \
|
||||
--guidance_scale=1.0 \
|
||||
--timestep_sampling="flux_shift" \
|
||||
--model_prediction_type="raw" \
|
||||
--blocks_to_swap=18 \
|
||||
--cache_text_encoder_outputs \
|
||||
--cache_latents
|
||||
```
|
||||
|
||||
### Training Chroma Models
|
||||
|
||||
If you want to train a Chroma model, specify `--model_type=chroma`. Chroma does not use CLIP-L, so the `--clip_l` argument is not needed. T5XXL and AE are same as FLUX.1. The command would look like this:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \
|
||||
--pretrained_model_name_or_path="<path to Chroma model>" \
|
||||
--model_type=chroma \
|
||||
--t5xxl="<path to T5-XXL model>" \
|
||||
--ae="<path to AE model>" \
|
||||
--dataset_config="my_flux_dataset_config.toml" \
|
||||
--output_dir="<output directory>" \
|
||||
--output_name="my_chroma_lora" \
|
||||
--guidance_scale=0.0 \
|
||||
--timestep_sampling="sigmoid" \
|
||||
--apply_t5_attn_mask \
|
||||
...
|
||||
```
|
||||
|
||||
Note that for Chroma models, `--guidance_scale=0.0` is required to disable guidance scale, and `--apply_t5_attn_mask` is needed to apply attention masks for T5XXL Text Encoder.
|
||||
|
||||
The sample image generation during training requires specifying a negative prompt. Also, set `--g 0` to disable embedded guidance scale and `--l 4.0` to set the CFG scale. For example:
|
||||
|
||||
```
|
||||
Japanese shrine in the summer forest. --n low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors --w 512 --h 512 --d 1 --l 4.0 --g 0.0 --s 20
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
学習は、ターミナルから`flux_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、FLUX.1特有の引数を指定する必要があります。
|
||||
|
||||
コマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
#### Chromaモデルの学習
|
||||
|
||||
Chromaモデルを学習したい場合は、`--model_type=chroma`を指定します。ChromaはCLIP-Lを使用しないため、`--clip_l`引数は不要です。T5XXLとAEはFLUX.1と同様です。
|
||||
|
||||
コマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
学習中のサンプル画像生成には、ネガティブプロンプトを指定してください。また `--g 0` を指定して埋め込みガイダンススケールを無効化し、`--l 4.0` を指定してCFGスケールを設定します。
|
||||
|
||||
</details>
|
||||
|
||||
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||
|
||||
The script adds FLUX.1 specific arguments such as guidance scale, timestep sampling, block swapping, and options for training CLIP-L and T5-XXL LoRA modules. Some Stable Diffusion options like `--v2` and `--clip_skip` are not used.
|
||||
The script adds FLUX.1 specific arguments. For common arguments (like `--output_dir`, `--output_name`, `--network_module`, etc.), see the [`train_network.py` guide](train_network.md).
|
||||
|
||||
#### Model-related [Required]
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to FLUX.1/Chroma model>"` **[Required]**
|
||||
- Specifies the path to the base FLUX.1 or Chroma model `.safetensors` file. Diffusers format directories are not currently supported.
|
||||
* `--model_type=<model type>`
|
||||
- Specifies the type of base model for training. Choose from `flux` or `chroma`. Default is `flux`.
|
||||
* `--clip_l="<path to CLIP-L model>"` **[Required when flux is selected]**
|
||||
- Specifies the path to the CLIP-L Text Encoder model `.safetensors` file. Not needed when `--model_type=chroma`.
|
||||
* `--t5xxl="<path to T5-XXL model>"` **[Required]**
|
||||
- Specifies the path to the T5-XXL Text Encoder model `.safetensors` file.
|
||||
* `--ae="<path to AE model>"` **[Required]**
|
||||
- Specifies the path to the FLUX.1-compatible AutoEncoder model `.safetensors` file.
|
||||
|
||||
#### FLUX.1 Training Parameters
|
||||
|
||||
* `--guidance_scale=<float>`
|
||||
- FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `1.0` to disable guidance scale. Default is `3.5`, so be sure to specify this. Usually ignored for schnell version.
|
||||
- Chroma requires `--guidance_scale=0.0` to disable guidance scale.
|
||||
* `--timestep_sampling=<choice>`
|
||||
- Specifies the sampling method for timesteps (noise levels) during training. Choose from `sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`. Default is `sigma`. Recommended is `flux_shift`. For Chroma models, `sigmoid` is recommended.
|
||||
* `--sigmoid_scale=<float>`
|
||||
- Scale factor when `timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default and recommended value is `1.0`.
|
||||
* `--model_prediction_type=<choice>`
|
||||
- Specifies what the model predicts. Choose from `raw` (use prediction as-is), `additive` (add to noise input), `sigma_scaled` (apply sigma scaling). Default is `sigma_scaled`. Recommended is `raw`.
|
||||
* `--discrete_flow_shift=<float>`
|
||||
- Specifies the shift value for the scheduler used in Flow Matching. Default is `3.0`. This value is ignored when `timestep_sampling` is set to other than `shift`.
|
||||
|
||||
#### Memory/Speed Related
|
||||
|
||||
* `--fp8_base`
|
||||
- Enables training in FP8 format for FLUX.1, CLIP-L, and T5-XXL. This can significantly reduce VRAM usage, but the training results may vary.
|
||||
* `--blocks_to_swap=<integer>` **[Experimental Feature]**
|
||||
- Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`.
|
||||
- Cannot be used with `--cpu_offload_checkpointing`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
- Caches the outputs of CLIP-L and T5-XXL. This reduces memory usage.
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
- Caches the outputs of AE. Similar functionality to [sdxl_train_network.py](sdxl_train_network.md).
|
||||
|
||||
#### Incompatible/Deprecated Arguments
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip`: These are Stable Diffusion-specific arguments and are not used in FLUX.1 training.
|
||||
* `--max_token_length`: This is an argument for Stable Diffusion v1/v2. For FLUX.1, use `--t5xxl_max_token_length`.
|
||||
* `--split_mode`: Deprecated argument. Use `--blocks_to_swap` instead.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のFLUX.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。
|
||||
|
||||
コマンドラインの例と詳細な引数の説明は英語のドキュメントを参照してください。
|
||||
|
||||
</details>
|
||||
|
||||
### 4.2. Starting Training / 学習の開始
|
||||
|
||||
Training begins once you run the command with the required options. Log checking is the same as in `train_network.py`.
|
||||
Training begins once you run the command with the required options. Log checking is the same as in [`train_network.py`](train_network.md#32-starting-the-training--学習の開始).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
|
||||
|
||||
</details>
|
||||
|
||||
## 5. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
After training, a LoRA model file is saved in `output_dir` and can be used in inference environments supporting FLUX.1 (e.g. ComfyUI + Flux nodes).
|
||||
|
||||
## 6. Others / その他
|
||||
|
||||
Additional notes on VRAM optimization, training options, multi-resolution datasets, block selection and text encoder LoRA are provided in the Japanese section.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
|
||||
|
||||
# `flux_train_network.py` を用いたFLUX.1モデルのLoRA学習ガイド
|
||||
|
||||
このドキュメントでは、`sd-scripts`リポジトリに含まれる`flux_train_network.py`を使用して、FLUX.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
|
||||
|
||||
## 1. はじめに
|
||||
|
||||
`flux_train_network.py`は、FLUX.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。FLUX.1はStable Diffusionとは異なるアーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
**前提条件:**
|
||||
|
||||
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
|
||||
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)
|
||||
|
||||
## 2. `train_network.py` との違い
|
||||
|
||||
`flux_train_network.py`は`train_network.py`をベースに、FLUX.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** FLUX.1モデル(dev版またはschnell版)を対象とします。
|
||||
* **モデル構造:** Stable Diffusionとは異なり、FLUX.1はTransformerベースのアーキテクチャを持ちます。Text EncoderとしてCLIP-LとT5-XXLの二つを使用し、VAEの代わりに専用のAutoEncoder (AE) を使用します。
|
||||
* **必須の引数:** FLUX.1モデル、CLIP-L、T5-XXL、AEの各モデルファイルを指定する引数が追加されています。
|
||||
* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)はFLUX.1の学習では使用されません。
|
||||
* **FLUX.1特有の引数:** タイムステップのサンプリング方法やガイダンススケールなど、FLUX.1特有の学習パラメータを指定する引数が追加されています。
|
||||
|
||||
## 3. 準備
|
||||
|
||||
学習を開始する前に、以下のファイルが必要です。
|
||||
|
||||
1. **学習スクリプト:** `flux_train_network.py`
|
||||
2. **FLUX.1モデルファイル:** 学習のベースとなるFLUX.1モデルの`.safetensors`ファイル(例: `flux1-dev.safetensors`)。
|
||||
3. **Text Encoderモデルファイル:**
|
||||
* CLIP-Lモデルの`.safetensors`ファイル。例として`clip_l.safetensors`を使用します。
|
||||
* T5-XXLモデルの`.safetensors`ファイル。例として`t5xxl.safetensors`を使用します。
|
||||
4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。
|
||||
5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。
|
||||
|
||||
* 例として`my_flux_dataset_config.toml`を使用します。
|
||||
|
||||
## 4. 学習の実行
|
||||
|
||||
学習は、ターミナルから`flux_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、FLUX.1特有の引数を指定する必要があります。
|
||||
|
||||
以下に、基本的なコマンドライン実行例を示します。
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py
|
||||
--pretrained_model_name_or_path="<path to FLUX.1 model>"
|
||||
--clip_l="<path to CLIP-L model>"
|
||||
--t5xxl="<path to T5-XXL model>"
|
||||
--ae="<path to AE model>"
|
||||
--dataset_config="my_flux_dataset_config.toml"
|
||||
--output_dir="<output directory for training results>"
|
||||
--output_name="my_flux_lora"
|
||||
--save_model_as=safetensors
|
||||
--network_module=networks.lora_flux
|
||||
--network_dim=16
|
||||
--network_alpha=1
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--lr_scheduler="constant"
|
||||
--sdpa
|
||||
--max_train_epochs=10
|
||||
--save_every_n_epochs=1
|
||||
--mixed_precision="fp16"
|
||||
--gradient_checkpointing
|
||||
--guidance_scale=1.0
|
||||
--timestep_sampling="flux_shift"
|
||||
--blocks_to_swap=18
|
||||
--cache_text_encoder_outputs
|
||||
--cache_latents
|
||||
```
|
||||
|
||||
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||
|
||||
### 4.1. 主要なコマンドライン引数の解説(`train_network.py`からの追加・変更点)
|
||||
|
||||
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のFLUX.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。
|
||||
|
||||
#### モデル関連 [必須]
|
||||
|
||||
* `--pretrained_model_name_or_path="<path to FLUX.1 model>"` **[必須]**
|
||||
* 学習のベースとなるFLUX.1モデル(dev版またはschnell版)の`.safetensors`ファイルのパスを指定します。Diffusers形式のディレクトリは現在サポートされていません。
|
||||
* `--clip_l="<path to CLIP-L model>"` **[必須]**
|
||||
* CLIP-L Text Encoderモデルの`.safetensors`ファイルのパスを指定します。
|
||||
* `--t5xxl="<path to T5-XXL model>"` **[必須]**
|
||||
* T5-XXL Text Encoderモデルの`.safetensors`ファイルのパスを指定します。
|
||||
* `--ae="<path to AE model>"` **[必須]**
|
||||
* FLUX.1に対応するAutoEncoderモデルの`.safetensors`ファイルのパスを指定します。
|
||||
|
||||
#### FLUX.1 学習パラメータ
|
||||
|
||||
* `--guidance_scale=<float>`
|
||||
* FLUX.1 dev版は特定のガイダンススケール値で蒸留されていますが、学習時には `1.0` を指定してガイダンススケールを無効化します。デフォルトは`3.5`ですので、必ず指定してください。schnell版では通常無視されます。
|
||||
* `--timestep_sampling=<choice>`
|
||||
* 学習時に使用するタイムステップ(ノイズレベル)のサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift` から選択します。デフォルトは `sigma` です。推奨は `flux_shift` です。
|
||||
* `--sigmoid_scale=<float>`
|
||||
* `timestep_sampling` に `sigmoid` または `shift`, `flux_shift` を指定した場合のスケール係数です。デフォルトおよび推奨値は`1.0`です。
|
||||
* `--model_prediction_type=<choice>`
|
||||
* モデルが何を予測するかを指定します。`raw` (予測値をそのまま使用), `additive` (ノイズ入力に加算), `sigma_scaled` (シグマスケーリングを適用) から選択します。デフォルトは `sigma_scaled` です。推奨は `raw` です。
|
||||
* `--discrete_flow_shift=<float>`
|
||||
* Flow Matchingで使用されるスケジューラのシフト値を指定します。デフォルトは`3.0`です。`timestep_sampling`に`flux_shift`を指定した場合は、この値は無視されます。
|
||||
|
||||
#### メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>` **[実験的機能]**
|
||||
* VRAM使用量を削減するために、モデルの一部(Transformerブロック)をCPUとGPU間でスワップする設定です。スワップするブロック数を整数で指定します(例: `18`)。値を大きくするとVRAM使用量は減りますが、学習速度は低下します。GPUのVRAM容量に応じて調整してください。`gradient_checkpointing`と併用可能です。
|
||||
* `--cpu_offload_checkpointing`とは併用できません。
|
||||
* `--cache_text_encoder_outputs`
|
||||
* CLIP-LおよびT5-XXLの出力をキャッシュします。これにより、メモリ使用量が削減されます。
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
* AEの出力をキャッシュします。[sdxl_train_network.py](sdxl_train_network.md)と同様の機能です。
|
||||
|
||||
#### 非互換・非推奨の引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip`: Stable Diffusion特有の引数のため、FLUX.1学習では使用されません。
|
||||
* `--max_token_length`: Stable Diffusion v1/v2向けの引数です。FLUX.1では`--t5xxl_max_token_length`を使用してください。
|
||||
* `--split_mode`: 非推奨の引数です。代わりに`--blocks_to_swap`を使用してください。
|
||||
|
||||
### 4.2. 学習の開始
|
||||
|
||||
必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
|
||||
|
||||
## 5. 学習済みモデルの利用
|
||||
|
||||
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_flux_lora.safetensors`)が保存されます。このファイルは、FLUX.1モデルに対応した推論環境(例: ComfyUI + ComfyUI-FluxNodes)で使用できます。
|
||||
|
||||
## 6. その他
|
||||
</details>
|
||||
|
||||
`flux_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python flux_train_network.py --help`) を参照してください。
|
||||
## 6. Advanced Settings / 高度な設定
|
||||
|
||||
# FLUX.1 LoRA学習の補足説明
|
||||
### 6.1. VRAM Usage Optimization / VRAM使用量の最適化
|
||||
|
||||
以下は、以上の基本的なFLUX.1 LoRAの学習手順を補足するものです。より詳細な設定オプションなどについて説明します。
|
||||
FLUX.1 is a relatively large model, so GPUs without sufficient VRAM require optimization. Here are settings to reduce VRAM usage (with `--fp8_base`):
|
||||
|
||||
## 1. VRAM使用量の最適化
|
||||
#### Recommended Settings by GPU Memory
|
||||
|
||||
FLUX.1モデルは比較的大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。以下に、VRAM使用量を削減するための設定を紹介します。
|
||||
| GPU Memory | Recommended Settings |
|
||||
|------------|---------------------|
|
||||
| 24GB VRAM | Basic settings work fine (batch size 2) |
|
||||
| 16GB VRAM | Set batch size to 1 and use `--blocks_to_swap` |
|
||||
| 12GB VRAM | Use `--blocks_to_swap 16` and 8bit AdamW |
|
||||
| 10GB VRAM | Use `--blocks_to_swap 22`, recommend fp8 format for T5XXL |
|
||||
| 8GB VRAM | Use `--blocks_to_swap 28`, recommend fp8 format for T5XXL |
|
||||
|
||||
### 1.1 メモリ使用量別の推奨設定
|
||||
#### Key VRAM Reduction Options
|
||||
|
||||
| GPUメモリ | 推奨設定 |
|
||||
|----------|----------|
|
||||
| 24GB VRAM | 基本設定で問題なく動作します(バッチサイズ2) |
|
||||
| 16GB VRAM | バッチサイズ1に設定し、`--blocks_to_swap`を使用 |
|
||||
| 12GB VRAM | `--blocks_to_swap 16`と8bit AdamWを使用 |
|
||||
| 10GB VRAM | `--blocks_to_swap 22`を使用、T5XXLはfp8形式を推奨 |
|
||||
| 8GB VRAM | `--blocks_to_swap 28`を使用、T5XXLはfp8形式を推奨 |
|
||||
- **`--fp8_base`**: Enables training in FP8 format.
|
||||
|
||||
### 1.2 主要なVRAM削減オプション
|
||||
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. FLUX.1 supports up to 35 blocks for swapping.
|
||||
|
||||
- **`--blocks_to_swap <数値>`**:
|
||||
CPUとGPU間でブロックをスワップしてVRAM使用量を削減します。数値が大きいほど多くのブロックをスワップし、より多くのVRAMを節約できますが、学習速度は低下します。FLUX.1では最大35ブロックまでスワップ可能です。
|
||||
- **`--cpu_offload_checkpointing`**: Offloads gradient checkpoints to CPU. Can reduce VRAM usage by up to 1GB but decreases training speed by about 15%. Cannot be used with `--blocks_to_swap`. Chroma models do not support this option.
|
||||
|
||||
- **`--cpu_offload_checkpointing`**:
|
||||
勾配チェックポイントをCPUにオフロードします。これにより最大1GBのVRAM使用量を削減できますが、学習速度は約15%低下します。`--blocks_to_swap`とは併用できません。
|
||||
|
||||
- **`--cache_text_encoder_outputs` / `--cache_text_encoder_outputs_to_disk`**:
|
||||
CLIP-LとT5-XXLの出力をキャッシュします。これによりメモリ使用量を削減できます。
|
||||
|
||||
- **`--cache_latents` / `--cache_latents_to_disk`**:
|
||||
AEの出力をキャッシュします。メモリ使用量を削減できます。
|
||||
|
||||
- **Adafactor オプティマイザの使用**:
|
||||
8bit AdamWよりもVRAM使用量を削減できます。以下の設定を使用してください:
|
||||
- **Using Adafactor optimizer**: Can reduce VRAM usage more than 8bit AdamW:
|
||||
```
|
||||
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
|
||||
```
|
||||
|
||||
- **T5XXLのfp8形式の使用**:
|
||||
10GB未満のVRAMを持つGPUでは、T5XXLのfp8形式チェックポイントの使用を推奨します。[comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders)から`t5xxl_fp8_e4m3fn.safetensors`をダウンロードできます(`scaled`なしで使用してください)。
|
||||
- **Using T5XXL fp8 format**: For GPUs with less than 10GB VRAM, using fp8 format T5XXL checkpoints is recommended. Download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (use without `scaled`).
|
||||
|
||||
- **FP8/FP16 混合学習 [実験的機能]**:
|
||||
`--fp8_base_unet` オプションを指定すると、FLUX.1モデル本体をFP8形式で学習し、Text Encoder (CLIP-L/T5XXL) をBF16/FP16形式で学習できます。これにより、さらにVRAM使用量を削減できる可能性があります。このオプションを指定すると、`--fp8_base` オプションも自動的に有効になります。
|
||||
- **FP8/FP16 Mixed Training [Experimental]**: Specify `--fp8_base_unet` to train the FLUX.1 model in FP8 format while training Text Encoders (CLIP-L/T5XXL) in BF16/FP16 format. This can further reduce VRAM usage.
|
||||
|
||||
- **`pytorch-optimizer` の利用**:
|
||||
`pytorch-optimizer` ライブラリに含まれる様々なオプティマイザを使用できます。`requirements.txt` に追加されているため、別途インストールは不要です。
|
||||
例えば、CAME オプティマイザを使用する場合は以下のように指定します。
|
||||
```bash
|
||||
--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"
|
||||
|
||||
## 2. FLUX.1 LoRA学習の重要な設定オプション
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
FLUX.1の学習には多くの未知の点があり、いくつかの設定は引数で指定できます。以下に重要な引数とその説明を示します。
|
||||
FLUX.1モデルは比較的大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。VRAM使用量を削減するための設定の詳細は英語のドキュメントを参照してください。
|
||||
|
||||
### 2.1 タイムステップのサンプリング方法
|
||||
主要なVRAM削減オプション:
|
||||
- `--fp8_base`: FP8形式での学習を有効化
|
||||
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
|
||||
- `--cpu_offload_checkpointing`: 勾配チェックポイントをCPUにオフロード
|
||||
- Adafactorオプティマイザの使用
|
||||
- T5XXLのfp8形式の使用
|
||||
- FP8/FP16混合学習(実験的機能)
|
||||
|
||||
`--timestep_sampling`オプションで、タイムステップ(0-1)のサンプリング方法を指定できます:
|
||||
</details>
|
||||
|
||||
- `sigma`:SD3と同様のシグマベース
|
||||
- `uniform`:一様ランダム
|
||||
- `sigmoid`:正規分布乱数のシグモイド(x-flux、AI-toolkitなどと同様)
|
||||
- `shift`:正規分布乱数のシグモイド値をシフト
|
||||
- `flux_shift`:解像度に応じて正規分布乱数のシグモイド値をシフト(FLUX.1 dev推論と同様)。この設定では`--discrete_flow_shift`は無視されます。
|
||||
### 6.2. Important FLUX.1 LoRA Training Settings / FLUX.1 LoRA学習の重要な設定
|
||||
|
||||
FLUX.1 training has many unknowns, and several settings can be specified with arguments:
|
||||
|
||||
#### タイムステップ分布の可視化
|
||||
#### Timestep Sampling Methods
|
||||
|
||||
`--timestep_sampling`, `--sigmoid_scale`, `--discrete_flow_shift` の組み合わせによって、学習中にサンプリングされるタイムステップの分布が変化します。以下にいくつかの例を示します。
|
||||
The `--timestep_sampling` option specifies how timesteps (0-1) are sampled:
|
||||
|
||||
* `--timestep_sampling shift` と `--discrete_flow_shift` の効果 (`--sigmoid_scale` はデフォルトの1.0):
|
||||

|
||||
- `sigma`: Sigma-based like SD3
|
||||
- `uniform`: Uniform random
|
||||
- `sigmoid`: Sigmoid of normal distribution random (similar to x-flux, AI-toolkit)
|
||||
- `shift`: Sigmoid value of normal distribution random with shift. The `--discrete_flow_shift` setting is used to shift the sigmoid value.
|
||||
- `flux_shift`: Shift sigmoid value of normal distribution random according to resolution (similar to FLUX.1 dev inference).
|
||||
|
||||
* `--timestep_sampling sigmoid` と `--timestep_sampling uniform` の比較 (`--discrete_flow_shift` は無視される):
|
||||

|
||||
`--discrete_flow_shift` only applies when `--timestep_sampling` is set to `shift`.
|
||||
|
||||
* `--timestep_sampling sigmoid` と `--sigmoid_scale` の効果 (`--discrete_flow_shift` は無視される):
|
||||

|
||||
#### Model Prediction Processing
|
||||
|
||||
#### AI Toolkit 設定との比較
|
||||
The `--model_prediction_type` option specifies how to interpret and process model predictions:
|
||||
|
||||
[Ostris氏のAI Toolkit](https://github.com/ostris/ai-toolkit) で使用されている設定は、概ね以下のオプションに相当すると考えられます。
|
||||
```
|
||||
--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0
|
||||
```
|
||||
- `raw`: Use as-is (similar to x-flux) **[Recommended]**
|
||||
- `additive`: Add to noise input
|
||||
- `sigma_scaled`: Apply sigma scaling (similar to SD3)
|
||||
|
||||
### 2.2 モデル予測の処理方法
|
||||
#### Recommended Settings
|
||||
|
||||
`--model_prediction_type`オプションで、モデルの予測をどのように解釈し処理するかを指定できます:
|
||||
|
||||
- `raw`:そのまま使用(x-fluxと同様)【推奨】
|
||||
- `additive`:ノイズ入力に加算
|
||||
- `sigma_scaled`:シグマスケーリングを適用(SD3と同様)
|
||||
|
||||
### 2.3 推奨設定
|
||||
|
||||
実験の結果、以下の設定が良好に動作することが確認されています:
|
||||
Based on experiments, the following settings work well:
|
||||
```
|
||||
--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0
|
||||
```
|
||||
|
||||
ガイダンススケールについて:FLUX.1 dev版は特定のガイダンススケール値で蒸留されていますが、学習時には`--guidance_scale 1.0`を指定してガイダンススケールを無効化することを推奨します。
|
||||
|
||||
|
||||
### 2.4 T5 Attention Mask の適用
|
||||
|
||||
`--apply_t5_attn_mask` オプションを指定すると、T5XXL Text Encoder の学習および推論時に Attention Mask が適用されます。
|
||||
|
||||
Attention Maskに対応した推論環境が限られるため、このオプションは推奨されません。
|
||||
|
||||
### 2.5 IP ノイズガンマ
|
||||
|
||||
`--ip_noise_gamma` および `--ip_noise_gamma_random_strength` オプションを使用することで、学習時に Input Perturbation ノイズのガンマ値を調整できます。詳細は Stable Diffusion 3 の学習オプションを参照してください。
|
||||
|
||||
### 2.6 LoRA-GGPO サポート
|
||||
|
||||
LoRA-GGPO (Gradient Group Proportion Optimizer) を使用できます。これは LoRA の学習を安定化させるための手法です。以下の `network_args` を指定して有効化します。ハイパーパラメータ (`ggpo_sigma`, `ggpo_beta`) は調整が必要です。
|
||||
|
||||
```bash
|
||||
--network_args "ggpo_sigma=0.03" "ggpo_beta=0.01"
|
||||
For Chroma models, the following settings are recommended:
|
||||
```
|
||||
TOMLファイルで指定する場合:
|
||||
```toml
|
||||
network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]
|
||||
--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 0.0
|
||||
```
|
||||
|
||||
### 2.7 Q/K/V 射影層の分割 [実験的機能]
|
||||
**About Guidance Scale**: FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `--guidance_scale 1.0` to disable guidance scale. Chroma requires `--guidance_scale 0.0` to disable guidance scale because it is not distilled.
|
||||
|
||||
`--network_args "split_qkv=True"` を指定することで、Attention層内の Q/K/V (および SingleStreamBlock の Text) 射影層を個別に分割し、それぞれに LoRA を適用できます。
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
**技術的詳細:**
|
||||
FLUX.1 の元々の実装では、Q/K/V (および Text) の射影層は一つに結合されています。ここに LoRA を適用すると、一つの大きな LoRA モジュールが適用されます。一方、Diffusers の実装ではこれらの射影層は分離されており、それぞれに小さな LoRA モジュールが適用されます。このオプションは後者の挙動を模倣します。
|
||||
保存される LoRA モデルの互換性は維持されますが、内部的には分割された LoRA の重みを結合して保存するため、ゼロ要素が多くなりモデルサイズが大きくなる可能性があります。`convert_flux_lora.py` スクリプトを使用して Diffusers (AI-Toolkit) 形式に変換すると、サイズが削減されます。
|
||||
FLUX.1の学習には多くの未知の点があり、いくつかの設定は引数で指定できます。詳細な説明とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
## 3. 各層に対するランク指定
|
||||
主要な設定オプション:
|
||||
- タイムステップのサンプリング方法(`--timestep_sampling`)
|
||||
- モデル予測の処理方法(`--model_prediction_type`)
|
||||
- 推奨設定の組み合わせ
|
||||
|
||||
FLUX.1の各層に対して異なるランク(network_dim)を指定できます。これにより、特定の層に対してLoRAの効果を強調したり、無効化したりできます。
|
||||
</details>
|
||||
|
||||
以下のnetwork_argsを指定することで、各層のランクを指定できます。0を指定するとその層にはLoRAが適用されません。
|
||||
### 6.3. Layer-specific Rank Configuration / 各層に対するランク指定
|
||||
|
||||
| network_args | 対象レイヤー |
|
||||
You can specify different ranks (network_dim) for each layer of FLUX.1. This allows you to emphasize or disable LoRA effects for specific layers.
|
||||
|
||||
Specify the following network_args to set ranks for each layer. Setting 0 disables LoRA for that layer:
|
||||
|
||||
| network_args | Target Layer |
|
||||
|--------------|--------------|
|
||||
| img_attn_dim | DoubleStreamBlockのimg_attn |
|
||||
| txt_attn_dim | DoubleStreamBlockのtxt_attn |
|
||||
| img_mlp_dim | DoubleStreamBlockのimg_mlp |
|
||||
| txt_mlp_dim | DoubleStreamBlockのtxt_mlp |
|
||||
| img_mod_dim | DoubleStreamBlockのimg_mod |
|
||||
| txt_mod_dim | DoubleStreamBlockのtxt_mod |
|
||||
| single_dim | SingleStreamBlockのlinear1とlinear2 |
|
||||
| single_mod_dim | SingleStreamBlockのmodulation |
|
||||
| img_attn_dim | DoubleStreamBlock img_attn |
|
||||
| txt_attn_dim | DoubleStreamBlock txt_attn |
|
||||
| img_mlp_dim | DoubleStreamBlock img_mlp |
|
||||
| txt_mlp_dim | DoubleStreamBlock txt_mlp |
|
||||
| img_mod_dim | DoubleStreamBlock img_mod |
|
||||
| txt_mod_dim | DoubleStreamBlock txt_mod |
|
||||
| single_dim | SingleStreamBlock linear1 and linear2 |
|
||||
| single_mod_dim | SingleStreamBlock modulation |
|
||||
|
||||
使用例:
|
||||
Example usage:
|
||||
```
|
||||
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" "img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2"
|
||||
```
|
||||
|
||||
さらに、FLUXの条件付けレイヤーにLoRAを適用するには、network_argsに`in_dims`を指定します。5つの数値をカンマ区切りのリストとして指定する必要があります。
|
||||
To apply LoRA to FLUX conditioning layers, specify `in_dims` in network_args as a comma-separated list of 5 numbers:
|
||||
|
||||
例:
|
||||
```
|
||||
--network_args "in_dims=[4,2,2,2,4]"
|
||||
```
|
||||
|
||||
各数値は、`img_in`、`time_in`、`vector_in`、`guidance_in`、`txt_in`に対応します。上記の例では、すべての条件付けレイヤーにLoRAを適用し、`img_in`と`txt_in`のランクを4、その他のランクを2に設定しています。
|
||||
Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The example above applies LoRA to all conditioning layers with ranks of 4 for `img_in` and `txt_in`, and ranks of 2 for others.
|
||||
|
||||
0を指定するとそのレイヤーにはLoRAが適用されません。例えば、`[4,0,0,0,4]`は`img_in`と`txt_in`にのみLoRAを適用します。
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
## 4. 学習するブロックの指定
|
||||
FLUX.1の各層に対して異なるランク(network_dim)を指定できます。これにより、特定の層に対してLoRAの効果を強調したり、無効化したりできます。
|
||||
|
||||
FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_single_block_indices`を指定することで、学習するブロックを指定できます。インデックスは0ベースです。省略した場合のデフォルトはすべてのブロックを学習することです。
|
||||
詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
インデックスは、`0,1,5,8`のような整数のリストや、`0,1,4-5,7`のような整数の範囲として指定します。
|
||||
- double blocksの数は19なので、有効な範囲は0-18です
|
||||
- single blocksの数は38なので、有効な範囲は0-37です
|
||||
- `all`を指定するとすべてのブロックを学習します
|
||||
- `none`を指定するとブロックを学習しません
|
||||
</details>
|
||||
|
||||
使用例:
|
||||
### 6.4. Block Selection for Training / 学習するブロックの指定
|
||||
|
||||
You can specify which blocks to train using `train_double_block_indices` and `train_single_block_indices` in network_args. Indices are 0-based. Default is to train all blocks if omitted.
|
||||
|
||||
Specify indices as integer lists like `0,1,5,8` or integer ranges like `0,1,4-5,7`:
|
||||
- Double blocks: 19 blocks, valid range 0-18
|
||||
- Single blocks: 38 blocks, valid range 0-37
|
||||
- Specify `all` to train all blocks
|
||||
- Specify `none` to skip training blocks
|
||||
|
||||
Example usage:
|
||||
```
|
||||
--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37"
|
||||
```
|
||||
|
||||
または:
|
||||
Or:
|
||||
```
|
||||
--network_args "train_double_block_indices=none" "train_single_block_indices=10-15"
|
||||
```
|
||||
|
||||
`train_double_block_indices`または`train_single_block_indices`のどちらか一方だけを指定した場合、もう一方は通常通り学習されます。
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
## 5. Text Encoder LoRAのサポート
|
||||
FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_single_block_indices`を指定することで、学習するブロックを指定できます。
|
||||
|
||||
詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
</details>
|
||||
|
||||
### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定
|
||||
|
||||
You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer.
|
||||
|
||||
These settings are specified via the `network_args` argument.
|
||||
|
||||
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
|
||||
* Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
|
||||
* This sets the rank to 4 for modules whose names contain `single` and contain `_modulation`, and to 8 for modules containing `img_attn`.
|
||||
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
|
||||
* Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
|
||||
* This sets the learning rate to `1e-3` for modules whose names contain `single_blocks` followed by a digit (`0` to `9`) or `10`, and to `2e-3` for modules whose names contain `double_blocks`.
|
||||
|
||||
**Notes:**
|
||||
|
||||
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
|
||||
* If a module name matches multiple patterns, the setting from the last matching pattern in the string will be applied.
|
||||
* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
正規表現を用いて、LoRAのモジュールごとにランク(dim)や学習率を指定することができます。これにより、層ごとの指定よりも柔軟できめ細やかな制御が可能になります。
|
||||
|
||||
これらの設定は `network_args` 引数で指定します。
|
||||
|
||||
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank` という形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
|
||||
* この例では、名前に `single` で始まり `_modulation` を含むモジュールのランクを4に、`img_attn` を含むモジュールのランクを8に設定します。
|
||||
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr` という形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
|
||||
* この例では、名前が `single_blocks` で始まり、後に数字(`0`から`9`)または`10`が続くモジュールの学習率を `1e-3` に、`double_blocks` を含むモジュールの学習率を `2e-3` に設定します。
|
||||
**注意点:**
|
||||
|
||||
* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。
|
||||
* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。
|
||||
* これらの設定は、ブロック指定(`train_double_block_indices`, `train_single_block_indices`)が適用された後に行われます。
|
||||
|
||||
</details>
|
||||
|
||||
### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート
|
||||
|
||||
FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA:
|
||||
|
||||
- To train only FLUX.1: specify `--network_train_unet_only`
|
||||
- To train FLUX.1 and CLIP-L: omit `--network_train_unet_only`
|
||||
- To train FLUX.1, CLIP-L, and T5XXL: omit `--network_train_unet_only` and add `--network_args "train_t5xxl=True"`
|
||||
|
||||
You can specify individual learning rates for CLIP-L and T5XXL with `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5` sets the first value for CLIP-L and the second for T5XXL. Specifying one value uses the same learning rate for both. If `--text_encoder_lr` is not specified, the default `--learning_rate` is used for both.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポートしています。
|
||||
|
||||
- FLUX.1のみをトレーニングする場合は、`--network_train_unet_only`を指定します
|
||||
- FLUX.1とCLIP-Lをトレーニングする場合は、`--network_train_unet_only`を省略します
|
||||
- FLUX.1、CLIP-L、T5XXLすべてをトレーニングする場合は、`--network_train_unet_only`を省略し、`--network_args "train_t5xxl=True"`を追加します
|
||||
詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
CLIP-LとT5XXLの学習率は、`--text_encoder_lr`で個別に指定できます。例えば、`--text_encoder_lr 1e-4 1e-5`とすると、最初の値はCLIP-Lの学習率、2番目の値はT5XXLの学習率になります。1つだけ指定すると、CLIP-LとT5XXLの学習率は同じになります。`--text_encoder_lr`を指定しない場合、デフォルトの学習率`--learning_rate`が両方に使用されます。
|
||||
</details>
|
||||
|
||||
## 6. マルチ解像度トレーニング
|
||||
### 6.7. Multi-Resolution Training / マルチ解像度トレーニング
|
||||
|
||||
データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。
|
||||
You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution.
|
||||
|
||||
設定ファイルの例:
|
||||
Configuration file example:
|
||||
```toml
|
||||
[general]
|
||||
# 共通設定をここで定義
|
||||
# Common settings
|
||||
flip_aug = true
|
||||
color_aug = false
|
||||
keep_tokens_separator= "|||"
|
||||
@@ -425,85 +514,152 @@ caption_tag_dropout_rate = 0
|
||||
caption_extension = ".txt"
|
||||
|
||||
[[datasets]]
|
||||
# 最初の解像度の設定
|
||||
# First resolution settings
|
||||
batch_size = 2
|
||||
enable_bucket = true
|
||||
resolution = [1024, 1024]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "画像ディレクトリへのパス"
|
||||
image_dir = "path/to/image/directory"
|
||||
num_repeats = 1
|
||||
|
||||
[[datasets]]
|
||||
# 2番目の解像度の設定
|
||||
# Second resolution settings
|
||||
batch_size = 3
|
||||
enable_bucket = true
|
||||
resolution = [768, 768]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "画像ディレクトリへのパス"
|
||||
num_repeats = 1
|
||||
|
||||
[[datasets]]
|
||||
# 3番目の解像度の設定
|
||||
batch_size = 4
|
||||
enable_bucket = true
|
||||
resolution = [512, 512]
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "画像ディレクトリへのパス"
|
||||
image_dir = "path/to/image/directory"
|
||||
num_repeats = 1
|
||||
```
|
||||
|
||||
各解像度セクションの`[[datasets.subsets]]`部分は、データセットディレクトリを定義します。各解像度に対して同じディレクトリを指定してください。</details>
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
## 7. 検証 (Validation)
|
||||
データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。
|
||||
|
||||
学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。
|
||||
設定ファイルの例は英語のドキュメントを参照してください。
|
||||
|
||||
検証を設定するには、データセット設定 TOML ファイルに `[validation]` セクションを追加します。設定方法は学習データセットと同様ですが、`num_repeats` は通常 1 に設定します。
|
||||
</details>
|
||||
|
||||
### 6.8. Validation / 検証
|
||||
|
||||
You can calculate validation loss during training using a validation dataset to evaluate model generalization performance.
|
||||
|
||||
To set up validation, add a `[validation]` section to your dataset configuration TOML file. Configuration is similar to training datasets, but `num_repeats` is usually set to 1.
|
||||
|
||||
```toml
|
||||
# ... (学習データセットの設定) ...
|
||||
# ... (training dataset configuration) ...
|
||||
|
||||
[validation]
|
||||
batch_size = 1
|
||||
enable_bucket = true
|
||||
resolution = [1024, 1024] # 検証に使用する解像度
|
||||
resolution = [1024, 1024] # Resolution for validation
|
||||
|
||||
[[validation.subsets]]
|
||||
image_dir = "検証用画像ディレクトリへのパス"
|
||||
image_dir = "path/to/validation/images"
|
||||
num_repeats = 1
|
||||
caption_extension = ".txt"
|
||||
# ... 他の検証データセット固有の設定 ...
|
||||
# ... other validation dataset settings ...
|
||||
```
|
||||
|
||||
**注意点:**
|
||||
**Notes:**
|
||||
|
||||
* 検証損失の計算は、固定されたタイムステップサンプリングと乱数シードで行われます。これにより、ランダム性による損失の変動を抑え、より安定した評価が可能になります。
|
||||
* 現在のところ、`--blocks_to_swap` オプションを使用している場合、または Schedule-Free オプティマイザ (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`) を使用している場合は、検証損失はサポートされていません。
|
||||
* Validation loss calculation uses fixed timestep sampling and random seeds to reduce loss variation due to randomness for more stable evaluation.
|
||||
* Currently, validation loss is not supported when using `--blocks_to_swap` or Schedule-Free optimizers (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`).
|
||||
|
||||
## 8. データセット関連の追加オプション
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
### 8.1 リサイズ時の補間方法指定
|
||||
学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。
|
||||
|
||||
データセットの画像を学習解像度にリサイズする際の補間方法を指定できます。データセット設定 TOML ファイルの `[[datasets]]` セクションまたは `[general]` セクションで `interpolation_type` を指定します。
|
||||
詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
利用可能な値: `bicubic` (デフォルト), `bilinear`, `lanczos`, `nearest`, `area`
|
||||
</details>
|
||||
|
||||
## 7. Additional Options / 追加オプション
|
||||
|
||||
### 7.1. Other FLUX.1-specific Options / その他のFLUX.1特有のオプション
|
||||
|
||||
- **T5 Attention Mask Application**: Specify `--apply_t5_attn_mask` to apply attention masks during T5XXL Text Encoder training and inference. Not recommended due to limited inference environment support. **For Chroma models, this option is required.**
|
||||
|
||||
- **IP Noise Gamma**: Use `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` to adjust Input Perturbation noise gamma values during training. See Stable Diffusion 3 training options for details.
|
||||
|
||||
- **LoRA-GGPO Support**: Use LoRA-GGPO (Gradient Group Proportion Optimizer) to stabilize LoRA training:
|
||||
```bash
|
||||
--network_args "ggpo_sigma=0.03" "ggpo_beta=0.01"
|
||||
```
|
||||
|
||||
- **Q/K/V Projection Layer Splitting [Experimental]**: Specify `--network_args "split_qkv=True"` to individually split and apply LoRA to Q/K/V (and SingleStreamBlock Text) projection layers within Attention layers.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
その他のFLUX.1特有のオプション:
|
||||
- T5 Attention Maskの適用(Chromaモデルでは必須)
|
||||
- IPノイズガンマ
|
||||
- LoRA-GGPOサポート
|
||||
- Q/K/V射影層の分割(実験的機能)
|
||||
|
||||
詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
</details>
|
||||
|
||||
### 7.2. Dataset-related Additional Options / データセット関連の追加オプション
|
||||
|
||||
#### Interpolation Method for Resizing
|
||||
|
||||
You can specify the interpolation method when resizing dataset images to training resolution. Specify `interpolation_type` in the `[[datasets]]` or `[general]` section of the dataset configuration TOML file.
|
||||
|
||||
Available values: `bicubic` (default), `bilinear`, `lanczos`, `nearest`, `area`
|
||||
|
||||
```toml
|
||||
[[datasets]]
|
||||
resolution = [1024, 1024]
|
||||
enable_bucket = true
|
||||
interpolation_type = "lanczos" # 例: Lanczos補間を使用
|
||||
interpolation_type = "lanczos" # Example: Use Lanczos interpolation
|
||||
# ...
|
||||
```
|
||||
|
||||
## 9. 関連ツール
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています。
|
||||
データセットの画像を学習解像度にリサイズする際の補間方法を指定できます。
|
||||
|
||||
* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出します。
|
||||
* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など、他の形式に変換します。Q/K/V分割オプションで学習した場合、このスクリプトで変換するとモデルサイズを削減できます。
|
||||
* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージします。
|
||||
* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するためのシンプルな推論スクリプトです。
|
||||
設定方法とオプションの詳細は英語のドキュメントを参照してください。
|
||||
|
||||
</details>
|
||||
|
||||
## 8. Related Tools / 関連ツール
|
||||
|
||||
Several related scripts are provided for models trained with `flux_train_network.py` and to assist with the training process:
|
||||
|
||||
* **`networks/flux_extract_lora.py`**: Extracts LoRA models from the difference between trained and base models.
|
||||
* **`convert_flux_lora.py`**: Converts trained LoRA models to other formats like Diffusers (AI-Toolkit) format. When trained with Q/K/V split option, converting with this script can reduce model size.
|
||||
* **`networks/flux_merge_lora.py`**: Merges trained LoRA models into FLUX.1 base models.
|
||||
* **`flux_minimal_inference.py`**: Simple inference script for generating images with trained LoRA models. You can specify `flux` or `chroma` with the `--model_type` argument.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています:
|
||||
|
||||
* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出。
|
||||
* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など他の形式に変換。
|
||||
* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージ。
|
||||
* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するシンプルな推論スクリプト。
|
||||
`--model_type` 引数で `flux` または `chroma` を指定できます。
|
||||
|
||||
</details>
|
||||
|
||||
## 9. Others / その他
|
||||
|
||||
`flux_train_network.py` includes many features common with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these features, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script help (`python flux_train_network.py --help`).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`flux_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python flux_train_network.py --help`) を参照してください。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -78,16 +78,19 @@ def denoise(
|
||||
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
cfg_scale: Optional[float] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
# prepare classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
||||
do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0)
|
||||
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
if do_cfg:
|
||||
print("Using classifier free guidance")
|
||||
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
||||
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
||||
b_txt = torch.cat([neg_txt, txt], dim=0)
|
||||
b_vec = torch.cat([neg_vec, vec], dim=0)
|
||||
b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None
|
||||
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
||||
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
else:
|
||||
@@ -103,24 +106,29 @@ def denoise(
|
||||
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
# classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
if do_cfg:
|
||||
b_img = torch.cat([img, img], dim=0)
|
||||
else:
|
||||
b_img = img
|
||||
|
||||
y_input = b_vec
|
||||
|
||||
mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0])
|
||||
|
||||
pred = model(
|
||||
img=b_img,
|
||||
img_ids=b_img_ids,
|
||||
txt=b_txt,
|
||||
txt_ids=b_txt_ids,
|
||||
y=b_vec,
|
||||
y=y_input,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=b_t5_attn_mask,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
if neg_txt is not None and neg_vec is not None:
|
||||
if do_cfg:
|
||||
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
||||
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
||||
|
||||
@@ -134,7 +142,7 @@ def do_sample(
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
l_pooled: torch.Tensor,
|
||||
l_pooled: Optional[torch.Tensor],
|
||||
t5_out: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
num_steps: int,
|
||||
@@ -192,7 +200,7 @@ def do_sample(
|
||||
|
||||
def generate_image(
|
||||
model,
|
||||
clip_l: CLIPTextModel,
|
||||
clip_l: Optional[CLIPTextModel],
|
||||
t5xxl,
|
||||
ae,
|
||||
prompt: str,
|
||||
@@ -231,7 +239,7 @@ def generate_image(
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
||||
|
||||
# prepare fp8 models
|
||||
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
||||
if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
||||
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
||||
clip_l.to(clip_l_dtype) # fp8
|
||||
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
||||
@@ -267,18 +275,22 @@ def generate_image(
|
||||
|
||||
# prepare embeddings
|
||||
logger.info("Encoding prompts...")
|
||||
clip_l = clip_l.to(device)
|
||||
if clip_l is not None:
|
||||
clip_l = clip_l.to(device)
|
||||
t5xxl = t5xxl.to(device)
|
||||
|
||||
def encode(prpt: str):
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
with torch.no_grad():
|
||||
if is_fp8(clip_l_dtype):
|
||||
with accelerator.autocast():
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
if clip_l is not None:
|
||||
if is_fp8(clip_l_dtype):
|
||||
with accelerator.autocast():
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||
l_pooled = None
|
||||
|
||||
if is_fp8(t5xxl_dtype):
|
||||
with accelerator.autocast():
|
||||
@@ -288,7 +300,7 @@ def generate_image(
|
||||
else:
|
||||
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||
)
|
||||
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
||||
|
||||
@@ -299,13 +311,14 @@ def generate_image(
|
||||
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
||||
|
||||
# NaN check
|
||||
if torch.isnan(l_pooled).any():
|
||||
if l_pooled is not None and torch.isnan(l_pooled).any():
|
||||
raise ValueError("NaN in l_pooled")
|
||||
if torch.isnan(t5_out).any():
|
||||
raise ValueError("NaN in t5_out")
|
||||
|
||||
if args.offload:
|
||||
clip_l = clip_l.cpu()
|
||||
if clip_l is not None:
|
||||
clip_l = clip_l.cpu()
|
||||
t5xxl = t5xxl.cpu()
|
||||
# del clip_l, t5xxl
|
||||
device_utils.clean_memory()
|
||||
@@ -318,6 +331,7 @@ def generate_image(
|
||||
|
||||
img_ids = img_ids.to(device)
|
||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||
neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None
|
||||
|
||||
x = do_sample(
|
||||
accelerator,
|
||||
@@ -385,6 +399,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use")
|
||||
parser.add_argument("--clip_l", type=str, required=False)
|
||||
parser.add_argument("--t5xxl", type=str, required=False)
|
||||
parser.add_argument("--ae", type=str, required=False)
|
||||
@@ -438,10 +453,13 @@ if __name__ == "__main__":
|
||||
else:
|
||||
accelerator = None
|
||||
|
||||
# load clip_l
|
||||
logger.info(f"Loading clip_l from {args.clip_l}...")
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
||||
clip_l.eval()
|
||||
# load clip_l (skip for chroma model)
|
||||
if args.model_type == "flux":
|
||||
logger.info(f"Loading clip_l from {args.clip_l}...")
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
||||
clip_l.eval()
|
||||
else:
|
||||
clip_l = None
|
||||
|
||||
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
||||
@@ -453,7 +471,7 @@ if __name__ == "__main__":
|
||||
# t5xxl = accelerator.prepare(t5xxl)
|
||||
|
||||
# DiT
|
||||
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
||||
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
|
||||
model.eval()
|
||||
logger.info(f"Casting model to {flux_dtype}")
|
||||
model.to(flux_dtype) # make sure model is dtype
|
||||
|
||||
@@ -271,7 +271,7 @@ def train(args):
|
||||
|
||||
# load FLUX
|
||||
_, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
|
||||
@@ -68,6 +68,11 @@ def train(args):
|
||||
if not args.skip_cache_check:
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
if args.model_type != "flux":
|
||||
raise ValueError(
|
||||
f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。"
|
||||
)
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
@@ -259,7 +264,7 @@ def train(args):
|
||||
|
||||
# load FLUX
|
||||
is_schnell, flux = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
|
||||
)
|
||||
flux.requires_grad_(False)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.model_type: Optional[str] = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
self.model_type = args.model_type # "flux" or "chroma"
|
||||
if self.model_type != "chroma":
|
||||
self.use_clip_l = True
|
||||
else:
|
||||
self.use_clip_l = False # Chroma does not use CLIP-L
|
||||
|
||||
if args.fp8_base_unet:
|
||||
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
||||
|
||||
@@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# prepare CLIP-L/T5XXL training flags
|
||||
self.train_clip_l = not args.network_train_unet_only
|
||||
self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
|
||||
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
||||
|
||||
if args.max_token_length is not None:
|
||||
@@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
self.is_schnell, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
||||
_, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
"cpu",
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
model_type=self.model_type,
|
||||
)
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
@@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
if self.use_clip_l:
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
else:
|
||||
clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L
|
||||
clip_l.eval()
|
||||
|
||||
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
||||
@@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
|
||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
||||
model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
|
||||
return model_version, [clip_l, t5xxl], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
|
||||
# Instead, we analyze the checkpoint state to determine if it is schnell.
|
||||
if args.model_type != "chroma":
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
else:
|
||||
is_schnell = False
|
||||
self.is_schnell = is_schnell
|
||||
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if is_schnell:
|
||||
if self.is_schnell:
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
@@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# # get size embeddings
|
||||
# orig_size = batch["original_sizes_hw"]
|
||||
# crop_size = batch["crop_top_lefts"]
|
||||
# target_size = batch["target_sizes_hw"]
|
||||
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# # concat embeddings
|
||||
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
@@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
# return
|
||||
|
||||
"""
|
||||
class FluxUpperLowerWrapper(torch.nn.Module):
|
||||
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
||||
super().__init__()
|
||||
self.flux_upper = flux_upper
|
||||
self.flux_lower = flux_lower
|
||||
self.target_device = device
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
pass
|
||||
|
||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
||||
self.flux_lower.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_upper.to(self.target_device)
|
||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_lower.to(self.target_device)
|
||||
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
"""
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
@@ -366,7 +340,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# ensure guidance_scale in args is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
# get modulation vectors for Chroma
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
@@ -374,13 +351,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t.requires_grad_(True)
|
||||
img_ids.requires_grad_(True)
|
||||
guidance_vec.requires_grad_(True)
|
||||
if mod_vectors is not None:
|
||||
mod_vectors.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -393,6 +372,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
return model_pred
|
||||
|
||||
@@ -405,6 +385,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps,
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
mod_vectors=mod_vectors,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
@@ -436,6 +417,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
@@ -454,9 +436,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
||||
if self.model_type != "chroma":
|
||||
model_description = "schnell" if self.is_schnell else "dev"
|
||||
else:
|
||||
model_description = "chroma"
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_model_type"] = args.model_type
|
||||
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
|
||||
744
library/chroma_models.py
Normal file
744
library/chroma_models.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py
|
||||
# and modified
|
||||
# licensed under Apache License 2.0
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as ckpt
|
||||
|
||||
from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux
|
||||
from . import custom_offloading_utils
|
||||
|
||||
|
||||
def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
|
||||
"""
|
||||
Distributes slices of the tensor into the block_dict as ModulationOut objects.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
|
||||
"""
|
||||
batch_size, vectors, dim = tensor.shape
|
||||
|
||||
block_dict = {}
|
||||
|
||||
# HARD CODED VALUES! lookup table for the generated vectors
|
||||
# TODO: move this into chroma config!
|
||||
# Add 38 single mod blocks
|
||||
for i in range(depth_single_blocks):
|
||||
key = f"single_blocks.{i}.modulation.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 image double blocks
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.img_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add 19 text double blocks
|
||||
for i in range(depth_double_blocks):
|
||||
key = f"double_blocks.{i}.txt_mod.lin"
|
||||
block_dict[key] = None
|
||||
|
||||
# Add the final layer
|
||||
block_dict["final_layer.adaLN_modulation.1"] = None
|
||||
# 6.2b version
|
||||
# block_dict["lite_double_blocks.4.img_mod.lin"] = None
|
||||
# block_dict["lite_double_blocks.4.txt_mod.lin"] = None
|
||||
|
||||
idx = 0 # Index to keep track of the vector slices
|
||||
|
||||
for key in block_dict.keys():
|
||||
if "single_blocks" in key:
|
||||
# Single block: 1 ModulationOut
|
||||
block_dict[key] = ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors
|
||||
|
||||
elif "img_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "txt_mod" in key:
|
||||
# Double block: List of 2 ModulationOut
|
||||
double_block = []
|
||||
for _ in range(2): # Create 2 ModulationOut objects
|
||||
double_block.append(
|
||||
ModulationOut(
|
||||
shift=tensor[:, idx : idx + 1, :],
|
||||
scale=tensor[:, idx + 1 : idx + 2, :],
|
||||
gate=tensor[:, idx + 2 : idx + 3, :],
|
||||
)
|
||||
)
|
||||
idx += 3 # Advance by 3 vectors per ModulationOut
|
||||
block_dict[key] = double_block
|
||||
|
||||
elif "final_layer" in key:
|
||||
# Final layer: 1 ModulationOut
|
||||
block_dict[key] = [
|
||||
tensor[:, idx : idx + 1, :],
|
||||
tensor[:, idx + 1 : idx + 2, :],
|
||||
]
|
||||
idx += 2 # Advance by 3 vectors
|
||||
|
||||
return block_dict
|
||||
|
||||
|
||||
class Approximator(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
|
||||
super().__init__()
|
||||
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)])
|
||||
self.out_proj = nn.Linear(hidden_dim, out_dim)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
for layer in self.layers:
|
||||
layer.enable_gradient_checkpointing()
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
for layer in self.layers:
|
||||
layer.disable_gradient_checkpointing()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.in_proj(x)
|
||||
|
||||
for layer, norms in zip(self.layers, self.norms):
|
||||
x = x + layer(norms(x))
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
def _modulation_shift_scale_fn(x, scale, shift):
|
||||
return (1 + scale) * x + shift
|
||||
|
||||
|
||||
def _modulation_gate_fn(x, gate, gate_params):
|
||||
return x + gate * gate_params
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
qkv_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def modulation_gate_fn(self, x, gate, gate_params):
|
||||
return _modulation_gate_fn(x, gate, gate_params)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
pe: list[Tensor],
|
||||
distill_vec: list[ModulationOut],
|
||||
txt_seq_len: Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
# replaced with compiled fn
|
||||
# img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
# replaced with compiled fn
|
||||
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention: we split the batch into each element
|
||||
max_txt_len = torch.max(txt_seq_len).item()
|
||||
img_len = img_q.shape[-2] # max 64
|
||||
txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0)) # list of [B, H, L, D] tensors
|
||||
txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0))
|
||||
txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0))
|
||||
img_q = list(torch.chunk(img_q, img_q.shape[0], dim=0))
|
||||
img_k = list(torch.chunk(img_k, img_k.shape[0], dim=0))
|
||||
img_v = list(torch.chunk(img_v, img_v.shape[0], dim=0))
|
||||
txt_attn = []
|
||||
img_attn = []
|
||||
for i in range(txt.shape[0]):
|
||||
txt_q[i] = txt_q[i][:, :, : txt_seq_len[i]]
|
||||
q = torch.cat((img_q[i], txt_q[i]), dim=2)
|
||||
txt_q[i] = None
|
||||
img_q[i] = None
|
||||
|
||||
txt_k[i] = txt_k[i][:, :, : txt_seq_len[i]]
|
||||
k = torch.cat((img_k[i], txt_k[i]), dim=2)
|
||||
txt_k[i] = None
|
||||
img_k[i] = None
|
||||
|
||||
txt_v[i] = txt_v[i][:, :, : txt_seq_len[i]]
|
||||
v = torch.cat((img_v[i], txt_v[i]), dim=2)
|
||||
txt_v[i] = None
|
||||
img_v[i] = None
|
||||
|
||||
attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D)
|
||||
del q, k, v
|
||||
img_attn_i = attn[:, :img_len, :]
|
||||
txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device)
|
||||
txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :]
|
||||
del attn
|
||||
txt_attn.append(txt_attn_i)
|
||||
img_attn.append(img_attn_i)
|
||||
|
||||
txt_attn = torch.cat(txt_attn, dim=0)
|
||||
img_attn = torch.cat(img_attn, dim=0)
|
||||
|
||||
# q = torch.cat((txt_q, img_q), dim=2)
|
||||
# k = torch.cat((txt_k, img_k), dim=2)
|
||||
# v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
# attn = attention(q, k, v, pe=pe, attn_mask=mask)
|
||||
# txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img blocks
|
||||
# replaced with compiled fn
|
||||
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
del img_attn, img_mod1
|
||||
img = self.modulation_gate_fn(
|
||||
img,
|
||||
img_mod2.gate,
|
||||
self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
|
||||
)
|
||||
del img_mod2
|
||||
|
||||
# calculate the txt blocks
|
||||
# replaced with compiled fn
|
||||
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
del txt_attn, txt_mod1
|
||||
txt = self.modulation_gate_fn(
|
||||
txt,
|
||||
txt_mod2.gate,
|
||||
self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
|
||||
)
|
||||
del txt_mod2
|
||||
|
||||
return img, txt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
pe: Tensor,
|
||||
distill_vec: list[ModulationOut],
|
||||
txt_seq_len: Tensor,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, txt_seq_len, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(img, txt, pe, distill_vec, txt_seq_len)
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def modulation_gate_fn(self, x, gate, gate_params):
|
||||
return _modulation_gate_fn(x, gate, gate_params)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _forward(self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor:
|
||||
mod = distill_vec
|
||||
# replaced with compiled fn
|
||||
# x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
del x_mod
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# # compute attention
|
||||
# attn = attention(q, k, v, pe=pe, attn_mask=mask)
|
||||
|
||||
# compute attention: we split the batch into each element
|
||||
max_txt_len = torch.max(txt_seq_len).item()
|
||||
img_len = q.shape[-2] - max_txt_len
|
||||
q = list(torch.chunk(q, q.shape[0], dim=0))
|
||||
k = list(torch.chunk(k, k.shape[0], dim=0))
|
||||
v = list(torch.chunk(v, v.shape[0], dim=0))
|
||||
attn = []
|
||||
for i in range(x.size(0)):
|
||||
q[i] = q[i][:, :, : img_len + txt_seq_len[i]]
|
||||
k[i] = k[i][:, :, : img_len + txt_seq_len[i]]
|
||||
v[i] = v[i][:, :, : img_len + txt_seq_len[i]]
|
||||
attn_trimmed = attention(q[i], k[i], v[i], pe=pe[i : i + 1, :, : img_len + txt_seq_len[i]], attn_mask=None)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
|
||||
attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device)
|
||||
attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed
|
||||
del attn_trimmed
|
||||
attn.append(attn_i)
|
||||
|
||||
attn = torch.cat(attn, dim=0)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
del attn, mlp
|
||||
# replaced with compiled fn
|
||||
# return x + mod.gate * output
|
||||
return self.modulation_gate_fn(x, mod.gate, output)
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, pe, distill_vec, txt_seq_len)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
# Get the device of the module (assumes all parameters are on the same device)
|
||||
return next(self.parameters()).device
|
||||
|
||||
def modulation_shift_scale_fn(self, x, scale, shift):
|
||||
return _modulation_shift_scale_fn(x, scale, shift)
|
||||
|
||||
def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor:
|
||||
shift, scale = distill_vec
|
||||
shift = shift.squeeze(1)
|
||||
scale = scale.squeeze(1)
|
||||
# replaced with compiled fn
|
||||
# x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.modulation_shift_scale_fn(self.norm_final(x), scale[:, None, :], shift[:, None, :])
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChromaParams:
|
||||
in_channels: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
approximator_in_dim: int
|
||||
approximator_depth: int
|
||||
approximator_hidden_size: int
|
||||
_use_compiled: bool
|
||||
|
||||
|
||||
chroma_params = ChromaParams(
|
||||
in_channels=64,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
approximator_in_dim=64,
|
||||
approximator_depth=5,
|
||||
approximator_hidden_size=5120,
|
||||
_use_compiled=False,
|
||||
)
|
||||
|
||||
|
||||
def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
|
||||
"""
|
||||
Modifies attention mask to allow attention to a few extra padding tokens.
|
||||
|
||||
Args:
|
||||
mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens)
|
||||
max_seq_length: Maximum sequence length of the model
|
||||
num_extra_padding: Number of padding tokens to unmask
|
||||
|
||||
Returns:
|
||||
Modified mask
|
||||
"""
|
||||
# Get the actual sequence length from the mask
|
||||
seq_length = mask.sum(dim=-1)
|
||||
batch_size = mask.shape[0]
|
||||
|
||||
modified_mask = mask.clone()
|
||||
|
||||
for i in range(batch_size):
|
||||
current_seq_len = int(seq_length[i].item())
|
||||
|
||||
# Only add extra padding tokens if there's room
|
||||
if current_seq_len < max_seq_length:
|
||||
# Calculate how many padding tokens we can unmask
|
||||
available_padding = max_seq_length - current_seq_len
|
||||
tokens_to_unmask = min(num_extra_padding, available_padding)
|
||||
|
||||
# Unmask the specified number of padding tokens right after the sequence
|
||||
modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1
|
||||
|
||||
return modified_mask
|
||||
|
||||
|
||||
class Chroma(Flux):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, params: ChromaParams):
|
||||
nn.Module.__init__(self)
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
|
||||
# TODO: need proper mapping for this approximator output!
|
||||
# currently the mapping is hardcoded in distribute_modulations function
|
||||
self.distilled_guidance_layer = Approximator(
|
||||
params.approximator_in_dim,
|
||||
self.hidden_size,
|
||||
params.approximator_hidden_size,
|
||||
params.approximator_depth,
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(
|
||||
self.hidden_size,
|
||||
1,
|
||||
self.out_channels,
|
||||
)
|
||||
|
||||
# TODO: move this hardcoded value to config
|
||||
# single layer has 3 modulation vectors
|
||||
# double layer has 6 modulation vectors for each expert
|
||||
# final layer has 2 modulation vectors
|
||||
self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
|
||||
self.depth_single_blocks = params.depth_single_blocks
|
||||
self.depth_double_blocks = params.depth
|
||||
# self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0)
|
||||
self.register_buffer(
|
||||
"mod_index",
|
||||
torch.tensor(list(range(self.mod_index_length)), device="cpu"),
|
||||
persistent=False,
|
||||
)
|
||||
self.approximator_in_dim = params.approximator_in_dim
|
||||
|
||||
self.blocks_to_swap = None
|
||||
self.offloader_double = None
|
||||
self.offloader_single = None
|
||||
self.num_double_blocks = len(self.double_blocks)
|
||||
self.num_single_blocks = len(self.single_blocks)
|
||||
|
||||
# Initialize properties required by Flux parent class
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def get_model_type(self) -> str:
|
||||
return "chroma"
|
||||
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
||||
self.gradient_checkpointing = True
|
||||
self.cpu_offload_checkpointing = cpu_offload
|
||||
|
||||
self.distilled_guidance_layer.enable_gradient_checkpointing()
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print(f"Chroma: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
self.distilled_guidance_layer.disable_gradient_checkpointing()
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("Chroma: Gradient checkpointing disabled.")
|
||||
|
||||
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
# We extract this logic from forward to clarify the propagation of the gradients
|
||||
# original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195
|
||||
|
||||
# print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
|
||||
# get all modulation index
|
||||
modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
|
||||
# we need to broadcast the modulation index here so each batch has all of the index
|
||||
modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
# and we need to broadcast timestep and guidance along too
|
||||
timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
|
||||
# then and only then we could concatenate it together
|
||||
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
|
||||
|
||||
mod_vectors = self.distilled_guidance_layer(input_vec)
|
||||
return mod_vectors
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
block_controlnet_hidden_states=None,
|
||||
block_controlnet_single_hidden_states=None,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
attn_padding: int = 1,
|
||||
mod_vectors: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
# print(
|
||||
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
|
||||
# )
|
||||
# print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}")
|
||||
# print(f"timesteps: {timesteps}, guidance: {guidance}")
|
||||
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
if mod_vectors is None: # fallback to the original logic
|
||||
with torch.no_grad():
|
||||
mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0])
|
||||
mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
|
||||
|
||||
# calculate text length for each batch instead of masking
|
||||
txt_emb_len = txt.shape[1]
|
||||
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, )
|
||||
txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
|
||||
max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch
|
||||
# print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}")
|
||||
|
||||
# trim txt embedding to the text length
|
||||
txt = txt[:, :max_txt_len, :]
|
||||
|
||||
# create positional encoding for the text and image
|
||||
ids = torch.cat((img_ids, txt_ids[:, :max_txt_len]), dim=1) # reverse order of ids for faster attention
|
||||
pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2
|
||||
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_double.wait_for_block(i)
|
||||
|
||||
# the guidance replaced by FFN output
|
||||
img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin")
|
||||
txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin")
|
||||
double_mod = [img_mod, txt_mod]
|
||||
del img_mod, txt_mod
|
||||
|
||||
img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
|
||||
del double_mod
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_double.submit_move_blocks(self.double_blocks, i)
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
del txt
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.wait_for_block(i)
|
||||
|
||||
single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin")
|
||||
img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len)
|
||||
del single_mod
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.submit_move_blocks(self.single_blocks, i)
|
||||
|
||||
img = img[:, :-max_txt_len, ...]
|
||||
final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
|
||||
img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
@@ -930,6 +930,9 @@ class Flux(nn.Module):
|
||||
self.num_double_blocks = len(self.double_blocks)
|
||||
self.num_single_blocks = len(self.single_blocks)
|
||||
|
||||
def get_model_type(self) -> str:
|
||||
return "flux"
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
@@ -1006,6 +1009,9 @@ class Flux(nn.Module):
|
||||
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
||||
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
||||
|
||||
def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
return None # FLUX.1 does not use mod_vectors, but Chroma does.
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -1018,6 +1024,7 @@ class Flux(nn.Module):
|
||||
block_controlnet_single_hidden_states=None,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
mod_vectors: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -1169,7 +1176,7 @@ class ControlNetFlux(nn.Module):
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1))
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1320,174 +1327,3 @@ class ControlNetFlux(nn.Module):
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
|
||||
|
||||
return controlnet_block_samples, controlnet_single_block_samples
|
||||
|
||||
|
||||
"""
|
||||
class FluxUpper(nn.Module):
|
||||
""
|
||||
Transformer model for flow matching on sequences.
|
||||
""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.disable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
return img, txt, vec, pe
|
||||
|
||||
|
||||
class FluxLower(nn.Module):
|
||||
""
|
||||
Transformer model for flow matching on sequences.
|
||||
""
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.out_channels = params.in_channels
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.disable_gradient_checkpointing()
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
"""
|
||||
|
||||
@@ -154,9 +154,8 @@ def sample_image_inference(
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5)
|
||||
cfg_scale = prompt_dict.get("scale", 1.0)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -242,7 +241,7 @@ def sample_image_inference(
|
||||
dtype=weight_dtype,
|
||||
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
||||
)
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
||||
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||
|
||||
@@ -403,8 +402,8 @@ def denoise(
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
timesteps=t_vec.repeat(2),
|
||||
guidance=guidance_vec.repeat(2),
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
@@ -680,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
choices=["flux", "chroma"],
|
||||
default="flux",
|
||||
help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)",
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from library.utils import load_safetensors
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
MODEL_VERSION_CHROMA = "chroma"
|
||||
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
@@ -92,50 +93,84 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
ckpt_path: str,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
model_type: str = "flux",
|
||||
) -> Tuple[bool, flux_models.Flux]:
|
||||
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
if model_type == "flux":
|
||||
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
|
||||
# build model
|
||||
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
||||
with torch.device("meta"):
|
||||
params = flux_models.configs[name].params
|
||||
# build model
|
||||
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
||||
with torch.device("meta"):
|
||||
params = flux_models.configs[name].params
|
||||
|
||||
# set the number of blocks
|
||||
if params.depth != num_double_blocks:
|
||||
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
||||
params = replace(params, depth=num_double_blocks)
|
||||
if params.depth_single_blocks != num_single_blocks:
|
||||
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
||||
params = replace(params, depth_single_blocks=num_single_blocks)
|
||||
# set the number of blocks
|
||||
if params.depth != num_double_blocks:
|
||||
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
||||
params = replace(params, depth=num_double_blocks)
|
||||
if params.depth_single_blocks != num_single_blocks:
|
||||
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
||||
params = replace(params, depth_single_blocks=num_single_blocks)
|
||||
|
||||
model = flux_models.Flux(params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
model = flux_models.Flux(params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = {}
|
||||
for ckpt_path in ckpt_paths:
|
||||
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = {}
|
||||
for ckpt_path in ckpt_paths:
|
||||
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
|
||||
|
||||
# convert Diffusers to BFL
|
||||
if is_diffusers:
|
||||
logger.info("Converting Diffusers to BFL")
|
||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||
logger.info("Converted Diffusers to BFL")
|
||||
# convert Diffusers to BFL
|
||||
if is_diffusers:
|
||||
logger.info("Converting Diffusers to BFL")
|
||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||
logger.info("Converted Diffusers to BFL")
|
||||
|
||||
# if the key has annoying prefix, remove it
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
# if the key has annoying prefix, remove it
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return is_schnell, model
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return is_schnell, model
|
||||
|
||||
elif model_type == "chroma":
|
||||
from . import chroma_models
|
||||
|
||||
# build model
|
||||
logger.info("Building Chroma model")
|
||||
with torch.device("meta"):
|
||||
model = chroma_models.Chroma(chroma_models.chroma_params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# if the key has annoying prefix, remove it
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Chroma: {info}")
|
||||
is_schnell = False # Chroma is not schnell
|
||||
return is_schnell, model
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
|
||||
|
||||
|
||||
def load_ae(
|
||||
@@ -166,7 +201,43 @@ def load_controlnet(
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = controlnet.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded ControlNet: {info}")
|
||||
return controlnet
|
||||
return controlnet
|
||||
|
||||
|
||||
def dummy_clip_l() -> torch.nn.Module:
|
||||
"""
|
||||
Returns a dummy CLIP-L model with the output shape of (N, 77, 768).
|
||||
"""
|
||||
return DummyCLIPL()
|
||||
|
||||
|
||||
class DummyTextModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embeddings = torch.nn.Parameter(torch.zeros(1))
|
||||
|
||||
|
||||
class DummyCLIPL(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
|
||||
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
|
||||
self.text_model = DummyTextModel()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.dummy_param.device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.dummy_param.dtype
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Returns a dummy output with the shape of (N, 77, 768).
|
||||
"""
|
||||
batch_size = args[0].shape[0] if args else 1
|
||||
return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
|
||||
@@ -60,6 +60,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
||||
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||
ARCH_FLUX_1_DEV = "flux-1-dev"
|
||||
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
|
||||
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
|
||||
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
||||
ARCH_LUMINA_2 = "lumina-2"
|
||||
ARCH_LUMINA_UNKNOWN = "lumina"
|
||||
@@ -71,6 +73,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
|
||||
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
@@ -129,7 +132,7 @@ def build_metadata(
|
||||
lumina: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
sd3: only supports "m", flux: only supports "dev"
|
||||
sd3: only supports "m", flux: supports "dev", "schnell" or "chroma"
|
||||
"""
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
@@ -148,6 +151,10 @@ def build_metadata(
|
||||
elif flux is not None:
|
||||
if flux == "dev":
|
||||
arch = ARCH_FLUX_1_DEV
|
||||
elif flux == "schnell":
|
||||
arch = ARCH_FLUX_1_SCHNELL
|
||||
elif flux == "chroma":
|
||||
arch = ARCH_FLUX_1_CHROMA
|
||||
else:
|
||||
arch = ARCH_FLUX_1_UNKNOWN
|
||||
elif lumina is not None:
|
||||
@@ -175,7 +182,10 @@ def build_metadata(
|
||||
|
||||
if flux is not None:
|
||||
# Flux
|
||||
impl = IMPL_FLUX
|
||||
if flux == "chroma":
|
||||
impl = IMPL_CHROMA
|
||||
else:
|
||||
impl = IMPL_FLUX
|
||||
elif lumina is not None:
|
||||
# Lumina
|
||||
impl = IMPL_LUMINA
|
||||
|
||||
@@ -3482,7 +3482,7 @@ def get_sai_model_spec(
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
flux: str = None,
|
||||
flux: str = None, # "dev", "schnell" or "chroma"
|
||||
lumina: str = None,
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -156,11 +156,19 @@ class LoRAModule(torch.nn.Module):
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
if (
|
||||
self.training
|
||||
and self.ggpo_sigma is not None
|
||||
and self.ggpo_beta is not None
|
||||
and self.combined_weight_norms is not None
|
||||
and self.grad_norms is not None
|
||||
):
|
||||
with torch.no_grad():
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
|
||||
self.ggpo_beta * (self.grad_norms**2)
|
||||
)
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(perturbation_scale_factor)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||
@@ -197,24 +205,24 @@ class LoRAModule(torch.nn.Module):
|
||||
# Choose a reasonable sample size
|
||||
n_rows = org_module_weight.shape[0]
|
||||
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||
|
||||
|
||||
# Sample random indices across all rows
|
||||
indices = torch.randperm(n_rows)[:sample_size]
|
||||
|
||||
|
||||
# Convert to a supported data type first, then index
|
||||
# Use float32 for indexing operations
|
||||
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||
|
||||
|
||||
# Calculate sampled norms
|
||||
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||
|
||||
|
||||
# Store the mean norm as our estimate
|
||||
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||
|
||||
|
||||
# Optional: store standard deviation for confidence intervals
|
||||
self.org_weight_norm_std = sampled_norms.std()
|
||||
|
||||
|
||||
# Free memory
|
||||
del sampled_weights, weights_float32
|
||||
|
||||
@@ -223,37 +231,36 @@ class LoRAModule(torch.nn.Module):
|
||||
# Calculate the true norm (this will be slow but it's just for validation)
|
||||
true_norms = []
|
||||
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||
|
||||
|
||||
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||
true_norms.append(chunk_norms.cpu())
|
||||
del chunk
|
||||
|
||||
|
||||
true_norms = torch.cat(true_norms, dim=0)
|
||||
true_mean_norm = true_norms.mean().item()
|
||||
|
||||
|
||||
# Compare with our estimate
|
||||
estimated_norm = self.org_weight_norm_estimate.item()
|
||||
|
||||
|
||||
# Calculate error metrics
|
||||
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||
|
||||
|
||||
if verbose:
|
||||
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||
|
||||
return {
|
||||
'true_mean_norm': true_mean_norm,
|
||||
'estimated_norm': estimated_norm,
|
||||
'absolute_error': absolute_error,
|
||||
'relative_error': relative_error
|
||||
}
|
||||
|
||||
return {
|
||||
"true_mean_norm": true_mean_norm,
|
||||
"estimated_norm": estimated_norm,
|
||||
"absolute_error": absolute_error,
|
||||
"relative_error": relative_error,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def update_norms(self):
|
||||
@@ -261,7 +268,7 @@ class LoRAModule(torch.nn.Module):
|
||||
if self.ggpo_beta is None or self.ggpo_sigma is None:
|
||||
return
|
||||
|
||||
# only update norms when we are training
|
||||
# only update norms when we are training
|
||||
if self.training is False:
|
||||
return
|
||||
|
||||
@@ -269,8 +276,9 @@ class LoRAModule(torch.nn.Module):
|
||||
module_weights.mul(self.scale)
|
||||
|
||||
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||
self.combined_weight_norms = torch.sqrt(
|
||||
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def update_grad_norms(self):
|
||||
@@ -293,7 +301,6 @@ class LoRAModule(torch.nn.Module):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
@@ -564,7 +571,6 @@ def create_network(
|
||||
if ggpo_sigma is not None:
|
||||
ggpo_sigma = float(ggpo_sigma)
|
||||
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -575,6 +581,42 @@ def create_network(
|
||||
if verbose is not None:
|
||||
verbose = True if verbose == "True" else False
|
||||
|
||||
# regex-specific learning rates
|
||||
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
|
||||
"""
|
||||
Parse a string of key-value pairs separated by commas.
|
||||
"""
|
||||
pairs = {}
|
||||
for pair in kv_pair_str.split(","):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
if "=" not in pair:
|
||||
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
|
||||
continue
|
||||
key, value = pair.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
try:
|
||||
pairs[key] = int(value) if is_int else float(value)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid value for {key}: {value}")
|
||||
return pairs
|
||||
|
||||
# parse regular expression based learning rates
|
||||
network_reg_lrs = kwargs.get("network_reg_lrs", None)
|
||||
if network_reg_lrs is not None:
|
||||
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
|
||||
else:
|
||||
reg_lrs = None
|
||||
|
||||
# regex-specific dimensions (ranks)
|
||||
network_reg_dims = kwargs.get("network_reg_dims", None)
|
||||
if network_reg_dims is not None:
|
||||
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
|
||||
else:
|
||||
reg_dims = None
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -594,8 +636,10 @@ def create_network(
|
||||
in_dims=in_dims,
|
||||
train_double_block_indices=train_double_block_indices,
|
||||
train_single_block_indices=train_single_block_indices,
|
||||
reg_dims=reg_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -613,7 +657,6 @@ def create_network(
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
|
||||
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
@@ -644,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
if train_t5xxl is None:
|
||||
train_t5xxl = False
|
||||
|
||||
# # split qkv
|
||||
# double_qkv_rank = None
|
||||
# single_qkv_rank = None
|
||||
# rank = None
|
||||
# for lora_name, dim in modules_dim.items():
|
||||
# if "double" in lora_name and "qkv" in lora_name:
|
||||
# double_qkv_rank = dim
|
||||
# elif "single" in lora_name and "linear1" in lora_name:
|
||||
# single_qkv_rank = dim
|
||||
# elif rank is None:
|
||||
# rank = dim
|
||||
# if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
|
||||
# break
|
||||
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
|
||||
# single_qkv_rank is not None and single_qkv_rank != rank
|
||||
# )
|
||||
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
|
||||
|
||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
@@ -708,8 +735,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
in_dims: Optional[List[int]] = None,
|
||||
train_double_block_indices: Optional[List[bool]] = None,
|
||||
train_single_block_indices: Optional[List[bool]] = None,
|
||||
reg_dims: Optional[Dict[str, int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -730,6 +759,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.in_dims = in_dims
|
||||
self.train_double_block_indices = train_double_block_indices
|
||||
self.train_single_block_indices = train_single_block_indices
|
||||
self.reg_dims = reg_dims
|
||||
self.reg_lrs = reg_lrs
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
@@ -757,7 +788,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
|
||||
|
||||
if train_t5xxl:
|
||||
logger.info(f"train T5XXL as well")
|
||||
|
||||
@@ -803,8 +833,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
else:
|
||||
# 通常、すべて対象とする
|
||||
elif self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.search(reg, lora_name):
|
||||
dim = d
|
||||
alpha = self.alpha
|
||||
logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}")
|
||||
break
|
||||
|
||||
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
|
||||
if dim is None and modules_dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha = self.alpha
|
||||
@@ -892,6 +930,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te = []
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
index = i
|
||||
if text_encoder is None:
|
||||
logger.info(f"Text Encoder {index+1} is None, skipping LoRA creation for this encoder.")
|
||||
continue
|
||||
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
||||
break
|
||||
|
||||
@@ -976,7 +1017,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
@@ -1163,17 +1203,77 @@ class LoRANetwork(torch.nn.Module):
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
|
||||
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
# regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...}
|
||||
reg_groups = {}
|
||||
|
||||
for lora in loras:
|
||||
# check if this lora matches any regex learning rate
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
try:
|
||||
if re.search(regex_str, lora.lora_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
except re.error:
|
||||
# regex error should have been caught during parsing, but just in case
|
||||
continue
|
||||
|
||||
for name, param in lora.named_parameters():
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
param_key = f"{lora.lora_name}.{name}"
|
||||
is_plus = loraplus_ratio is not None and "lora_up" in name
|
||||
|
||||
if matched_reg_lr is not None:
|
||||
# use regex-specific learning rate
|
||||
reg_idx, reg_lr = matched_reg_lr
|
||||
group_key = f"reg_lr_{reg_idx}"
|
||||
if group_key not in reg_groups:
|
||||
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
|
||||
|
||||
if is_plus:
|
||||
reg_groups[group_key]["plus"][param_key] = param
|
||||
else:
|
||||
reg_groups[group_key]["lora"][param_key] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
# use default learning rate
|
||||
if is_plus:
|
||||
param_groups["plus"][param_key] = param
|
||||
else:
|
||||
param_groups["lora"][param_key] = param
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
|
||||
# process regex-specific groups first (higher priority)
|
||||
for group_key in sorted(reg_groups.keys()):
|
||||
group = reg_groups[group_key]
|
||||
reg_lr = group["lr"]
|
||||
|
||||
for param_type in ["lora", "plus"]:
|
||||
if len(group[param_type]) == 0:
|
||||
continue
|
||||
|
||||
param_data = {"params": group[param_type].values()}
|
||||
|
||||
if param_type == "plus" and loraplus_ratio is not None:
|
||||
param_data["lr"] = reg_lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = reg_lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
desc = f"reg_lr_{group_key.split('_')[-1]}"
|
||||
if param_type == "plus":
|
||||
desc += " plus"
|
||||
descriptions.append(desc)
|
||||
|
||||
# process default groups
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
|
||||
@@ -645,7 +645,7 @@ class NetworkTrainer:
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
for net_arg in args.network_args:
|
||||
key, value = net_arg.split("=")
|
||||
key, value = net_arg.split("=", 1)
|
||||
net_kwargs[key] = value
|
||||
|
||||
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
||||
|
||||
Reference in New Issue
Block a user