Compare commits

...

11 Commits

Author SHA1 Message Date
Kohya S
a437949d47 feat: Add support for Safetensors format in caching strategies (WIP)
- Introduced Safetensors output format for various caching strategies including Hunyuan, Lumina, SD, SDXL, and SD3.
- Updated methods to handle loading and saving of tensors in Safetensors format.
- Enhanced output validation to check for required tensors in both NPZ and Safetensors formats.
- Modified dataset argument parser to include `--cache_format` option for selecting between NPZ and Safetensors formats.
- Updated caching logic to accommodate partial loading and merging of existing Safetensors files.
2026-03-22 21:15:12 +09:00
Kohya S
bd19e4c15d Merge branch 'main' into sd3 2026-03-22 21:10:51 +09:00
Kohya S.
7c159291e9 docs: add skip_image_resolution to config README (#2288)
* docs: add skip_image_resolution option to config README

Document the skip_image_resolution dataset option added in PR #2273.
Add option description, multi-resolution dataset TOML example, and
command-line argument entry to both Japanese and English config READMEs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: clarify `skip_image_resolution` functionality in dataset config

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 09:17:29 +09:00
woctordho
1cd95b2d8b Add skip_image_resolution to deduplicate multi-resolution dataset (#2273)
* Add min_orig_resolution and max_orig_resolution

* Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution

* Change skip_image_resolution to tuple

* Move filtering to __init__

* Minor fix
2026-03-19 08:43:39 +09:00
Kohya S
d633b51126 Merge branch 'dev' into sd3 2026-02-26 08:22:30 +09:00
Kohya S.
2217704ce1 feat: Support LoKr/LoHa for SDXL and Anima (#2275)
* feat: Add LoHa/LoKr network support for SDXL and Anima

- networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection
- networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions
- networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions
- library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA

Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: Enhance LoHa and LoKr modules with Tucker decomposition support

- Added Tucker decomposition functionality to LoHa and LoKr modules.
- Implemented new methods for weight rebuilding using Tucker decomposition.
- Updated initialization and weight handling for Conv2d 3x3+ layers.
- Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes.
- Enhanced network base to include unet_conv_target_modules for architecture detection.

* fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details

* doc: add dtype comment for load_safetensors_with_lora_and_fp8 function

* fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py

* doc: update model support structure to include Lumina Image 2.0, HunyuanImage-2.1, and Anima-Preview

* doc: add documentation for LoHa and LoKr fine-tuning methods

* Update networks/network_base.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update docs/loha_lokr.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-23 22:09:00 +09:00
Kohya S.
f90fa1a89a feat: backward compatibility for SD/SDXL latent cache (#2276)
* fix: improve handling of legacy npz files and add logging for fallback scenarios

* fix: simplify fallback handling in SdSdxlLatentsCachingStrategy
2026-02-23 21:44:51 +09:00
Kohya S.
98a42e4cd6 Merge pull request #2277 from kohya-ss/feat-stability-with-fp16-for-anima
feat: Stability with fp16 for anima
2026-02-23 21:15:49 +09:00
Kohya S
892f8be78f fix: cast input tensor to float32 for improved numerical stability in residual connections 2026-02-23 21:12:57 +09:00
woctordho
50694df3cf Multi-resolution dataset for SD1/SDXL (#2269)
* Multi-resolution dataset for SD1/SDXL

* Add fallback to legacy key without resolution suffix

* Support numpy 2.2
2026-02-23 15:30:36 +09:00
duongve13112002
609d1292f6 Fix the LoRA dropout issue in the Anima model during LoRA training. (#2272)
* Support network_reg_alphas and fix bug when setting rank_dropout in training lora for anima model

* Update anima_train_network.md

* Update anima_train_network.md

* Remove network_reg_alphas

* Update document
2026-02-23 15:13:40 +09:00
26 changed files with 3869 additions and 572 deletions

View File

@@ -21,6 +21,9 @@ Each supported model family has a consistent structure:
- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
- **SD3**: `sd3_train*.py`, `library/sd3_*`
- **FLUX.1**: `flux_train*.py`, `library/flux_*`
- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*`
- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*`
- **Anima-Preview**: `anima_train*.py`, `library/anima_*`
### Key Components

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
references

View File

@@ -122,11 +122,15 @@ These are options related to the configuration of the data set. They cannot be d
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
* `skip_image_resolution`
* Images whose original resolution (area) is equal to or smaller than the specified resolution will be skipped. Specify as `'size'` or `[width, height]`. This corresponds to the command-line argument `--skip_image_resolution`.
* Useful when sharing the same image directory across multiple datasets with different resolutions, to exclude low-resolution source images from higher-resolution datasets.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
@@ -254,6 +258,34 @@ resolution = 768
image_dir = 'C:\hoge'
```
When using multi-resolution datasets, you can use `skip_image_resolution` to exclude images whose original size is too small for higher-resolution datasets. This prevents overlapping of low-resolution images across datasets and improves training quality. This option can also be used to simply exclude low-resolution source images from datasets.
```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
In this example, the 1024-resolution dataset skips images whose original size is 768x768 or smaller, and the 1280-resolution dataset skips images whose original size is 1024x1024 or smaller.
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
@@ -284,6 +316,7 @@ For the command line options listed below, if an option is specified in both the
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |
## Error Guide

View File

@@ -115,11 +115,15 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
* `batch_size`
* コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
* `skip_image_resolution`
* 指定した解像度(面積)以下の画像をスキップします。`'サイズ'` または `[幅, 高さ]` で指定します。コマンドライン引数の `--skip_image_resolution` と同等です。
* 同じ画像ディレクトリを異なる解像度の複数のデータセットで使い回す場合に、低解像度の元画像を高解像度のデータセットから除外するために使用します。
これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
@@ -259,6 +263,34 @@ resolution = 768
image_dir = 'C:\hoge'
```
なお、マルチ解像度データセットでは `skip_image_resolution` を使用して、元の画像サイズが小さい画像を高解像度データセットから除外できます。これにより、低解像度画像のデータセット間での重複を防ぎ、学習品質を向上させることができます。また、小さい画像を除外するフィルターとしても機能します。
```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
この例では、1024 解像度のデータセットでは元の画像サイズが 768x768 以下の画像がスキップされ、1280 解像度のデータセットでは 1024x1024 以下の画像がスキップされます。
## コマンドライン引数との併用
設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
@@ -289,6 +321,7 @@ resolution = 768
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |
## エラーの手引き

359
docs/loha_lokr.md Normal file
View File

@@ -0,0 +1,359 @@
> 📝 Click on the language section to expand / 言語をクリックして展開
# LoHa / LoKr (LyCORIS)
## Overview / 概要
In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project.
- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md).
Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available.
This feature is experimental.
<details>
<summary>日本語</summary>
sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**Low-rank Hadamard Product**LoKr**Low-rank Kronecker Productをサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。
- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積要素ごとの積で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。
LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflatカーネル平坦化モードが利用可能です。
この機能は実験的なものです。
</details>
## Acknowledgments / 謝辞
The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible.
<details>
<summary>日本語</summary>
sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。
</details>
## Supported architectures / 対応アーキテクチャ
LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported:
- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`.
- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules.
<details>
<summary>日本語</summary>
LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています:
- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D``Downsample2D``Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。
- **Anima**: DiTの`Block``PatchEmbed``TimestepEmbedding``FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。
</details>
## Training / 学習
To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA.
<details>
<summary>日本語</summary>
LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプションデータセット設定、オプティマイザなどはLoRAと同じです。
</details>
### LoHa (SDXL)
```bash
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
--pretrained_model_name_or_path path/to/sdxl.safetensors \
--dataset_config path/to/toml \
--mixed_precision bf16 --fp8_base \
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
--network_module networks.loha --network_dim 32 --network_alpha 16 \
--max_train_epochs 16 --save_every_n_epochs 1 \
--output_dir path/to/output --output_name my-loha
```
### LoKr (SDXL)
```bash
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
--pretrained_model_name_or_path path/to/sdxl.safetensors \
--dataset_config path/to/toml \
--mixed_precision bf16 --fp8_base \
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
--network_module networks.lokr --network_dim 32 --network_alpha 16 \
--max_train_epochs 16 --save_every_n_epochs 1 \
--output_dir path/to/output --output_name my-lokr
```
For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options.
<details>
<summary>日本語</summary>
Animaの場合は、`sdxl_train_network.py``anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。
</details>
### Common training options / 共通の学習オプション
The following `--network_args` options are available for both LoHa and LoKr, same as LoRA:
| Option | Description |
|---|---|
| `verbose=True` | Display detailed information about the network modules |
| `rank_dropout=0.1` | Apply dropout to the rank dimension during training |
| `module_dropout=0.1` | Randomly skip entire modules during training |
| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) |
| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` |
| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns |
| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns |
<details>
<summary>日本語</summary>
以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます:
| オプション | 説明 |
|---|---|
| `verbose=True` | ネットワークモジュールの詳細情報を表示 |
| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 |
| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ |
| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) |
| `include_patterns=[r'...']` | 正規表現パターンに一致するモジュールのみを対象とする |
| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 |
| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 |
</details>
### Conv2d support / Conv2dサポート
By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options:
```bash
--network_args "conv_dim=16" "conv_alpha=8"
```
For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation:
```bash
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
```
- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode).
- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient.
<details>
<summary>日本語</summary>
デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層SDXLのResNetブロックなども学習するには、`conv_dim``conv_alpha`オプションを使用します:
```bash
--network_args "conv_dim=16" "conv_alpha=8"
```
Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です:
```bash
--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
```
- `use_tucker`なし: カーネル次元が入力次元に平坦化されますflatモード
- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。
</details>
### LoKr-specific option: `factor` / LoKr固有のオプション: `factor`
LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split:
- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32).
- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128).
```bash
--network_args "factor=4"
```
When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case.
<details>
<summary>日本語</summary>
LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します:
- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。
- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。
```bash
--network_args "factor=4"
```
`network_dim`ランクが因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。
</details>
### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter`
For Anima, you can additionally train the LLM adapter modules by specifying:
```bash
--network_args "train_llm_adapter=True"
```
This includes `LLMAdapterTransformerBlock` modules as training targets.
<details>
<summary>日本語</summary>
Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます:
```bash
--network_args "train_llm_adapter=True"
```
これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。
</details>
### LoRA+ / LoRA+
LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group.
```bash
--network_args "loraplus_lr_ratio=4"
```
This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them.
<details>
<summary>日本語</summary>
LoRA+`--network_args``loraplus_lr_ratio`はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列`hada_w2_a`が「plus」より高い学習率パラメータグループとして扱われます。LoKrではスケール係数`lokr_w1`が「plus」パラメータグループとして扱われます。
```bash
--network_args "loraplus_lr_ratio=4"
```
この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。
</details>
## How LoHa and LoKr work / LoHaとLoKrの仕組み
### LoHa
LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices:
```
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
```
where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product.
For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`.
### LoKr
LoKr represents the weight update using a Kronecker product:
```
ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode)
```
The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank.
<details>
<summary>日本語</summary>
### LoHa
LoHaは重みの更新を2つの低ランク行列のHadamard積要素ごとの積で表現します:
```
ΔW = (W1a × W1b) ⊙ (W2a × W2b)
```
ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。
Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。
### LoKr
LoKrはKronecker積を使って重みの更新を表現します:
```
ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b
```
元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます。W1は常にフル行列小さいで、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。
</details>
## Inference / 推論
Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA.
<details>
<summary>日本語</summary>
学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。
</details>
### SDXL
For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA:
```bash
python gen_img.py --ckpt path/to/sdxl.safetensors \
--network_module networks.loha --network_weights path/to/loha.safetensors \
--prompt "your prompt" ...
```
Replace `networks.loha` with `networks.lokr` for LoKr weights.
<details>
<summary>日本語</summary>
SDXLでは、LoRAと同様に `gen_img.py``--network_module``--network_weights` を指定します:
```bash
python gen_img.py --ckpt path/to/sdxl.safetensors \
--network_module networks.loha --network_weights path/to/loha.safetensors \
--prompt "your prompt" ...
```
LoKrの重みを使用する場合は `networks.loha``networks.lokr` に置き換えてください。
</details>
### Anima
For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged:
```bash
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
--lora_weight path/to/loha_or_lokr.safetensors ...
```
<details>
<summary>日本語</summary>
Animaでは、`anima_minimal_inference.py``--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます:
```bash
python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
--lora_weight path/to/loha_or_lokr.safetensors ...
```
</details>

View File

@@ -864,6 +864,10 @@ class Block(nn.Module):
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if x_B_T_H_W_D.dtype == torch.float16:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb

View File

@@ -108,6 +108,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
skip_image_resolution: Optional[Tuple[int, int]] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -244,6 +245,7 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
}
# options handled by argparse but not handled by user config
@@ -256,6 +258,7 @@ class ConfigSanitizer:
ARGPARSE_NULLABLE_OPTNAMES = [
"face_crop_aug_range",
"resolution",
"skip_image_resolution",
]
# prepare map because option name may differ among argparse and user config
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
@@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
skip_image_resolution: {dataset.skip_image_resolution}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")

View File

@@ -6,6 +6,9 @@ from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from networks.loha import merge_weights_to_tensor as loha_merge
from networks.lokr import merge_weights_to_tensor as lokr_merge
from library.utils import setup_logging
setup_logging()
@@ -65,6 +68,7 @@ def load_safetensors_with_lora_and_fp8(
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
@@ -130,9 +134,9 @@ def load_safetensors_with_lora_and_fp8(
if down_key in lora_weight_keys and up_key in lora_weight_keys:
found = True
break
if not found:
continue # no LoRA weights for this model weight
if found:
# Standard LoRA merge
# get LoRA weights
down_weight = lora_sd[down_key]
up_weight = lora_sd[up_key]
@@ -180,6 +184,22 @@ def load_safetensors_with_lora_and_fp8(
lora_weight_keys.remove(up_key)
if alpha_key in lora_weight_keys:
lora_weight_keys.remove(alpha_key)
continue
# Check for LoHa/LoKr weights with same prefix search
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
hada_key = lora_name + ".hada_w1_a"
lokr_key = lora_name + ".lokr_w1"
if hada_key in lora_weight_keys:
# LoHa merge
model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
break
elif lokr_key in lora_weight_keys:
# LoKr merge
model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
break
if not keep_on_calc_device and original_device != calc_device:
model_weight = model_weight.to(original_device) # move back to original device

View File

@@ -155,6 +155,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
"""
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_anima_te.safetensors"
def __init__(
self,
@@ -166,7 +167,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.ANIMA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk:
@@ -177,6 +179,23 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "prompt_embeds"):
return False
if "attn_mask" not in keys:
return False
if "t5_input_ids" not in keys:
return False
if "t5_attn_mask" not in keys:
return False
if "caption_dropout_rate" not in keys:
return False
else:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
@@ -195,6 +214,19 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
prompt_embeds = f.get_tensor(_find_tensor_by_prefix(keys, "prompt_embeds")).numpy()
attn_mask = f.get_tensor("attn_mask").numpy()
t5_input_ids = f.get_tensor("t5_input_ids").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
caption_dropout_rate = f.get_tensor("caption_dropout_rate").numpy()
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"]
@@ -219,6 +251,9 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, tokens_and_masks
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos)
else:
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
@@ -246,6 +281,46 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
def _cache_batch_outputs_safetensors(self, prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
prompt_embeds = prompt_embeds.cpu()
attn_mask = attn_mask.cpu()
t5_input_ids = t5_input_ids.cpu().to(torch.int32)
t5_attn_mask = t5_attn_mask.cpu().to(torch.int32)
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
pe = prompt_embeds[i]
tensors[f"prompt_embeds_{_dtype_to_str(pe.dtype)}"] = pe
tensors["attn_mask"] = attn_mask[i]
tensors["t5_input_ids"] = t5_input_ids[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["caption_dropout_rate"] = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
metadata = {
"architecture": "anima",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
info.text_encoder_outputs = (
prompt_embeds[i].numpy(),
attn_mask[i].numpy(),
t5_input_ids[i].numpy(),
t5_attn_mask[i].numpy(),
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""Latent caching strategy for Anima using WanVAE.
@@ -255,16 +330,20 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
ANIMA_LATENTS_ST_SUFFIX = "_anima.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_NPZ_SUFFIX
return self.ANIMA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "anima"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -2,7 +2,7 @@
import os
import re
from typing import Any, List, Optional, Tuple, Union, Callable
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import numpy as np
import torch
@@ -19,6 +19,48 @@ import logging
logger = logging.getLogger(__name__)
LATENTS_CACHE_FORMAT_VERSION = "1.0.1"
TE_OUTPUTS_CACHE_FORMAT_VERSION = "1.0.1"
# global cache format setting: "npz" or "safetensors"
_cache_format: str = "npz"
def set_cache_format(cache_format: str) -> None:
global _cache_format
_cache_format = cache_format
def get_cache_format() -> str:
return _cache_format
_TORCH_DTYPE_TO_STR = {
torch.float64: "float64",
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.int64: "int64",
torch.int32: "int32",
torch.int16: "int16",
torch.int8: "int8",
torch.uint8: "uint8",
torch.bool: "bool",
}
_FLOAT_DTYPES = {torch.float64, torch.float32, torch.float16, torch.bfloat16}
def _dtype_to_str(dtype: torch.dtype) -> str:
return _TORCH_DTYPE_TO_STR.get(dtype, str(dtype).replace("torch.", ""))
def _find_tensor_by_prefix(tensors_keys: List[str], prefix: str) -> Optional[str]:
"""Find a tensor key that starts with the given prefix. Returns the first match or None."""
for key in tensors_keys:
if key.startswith(prefix) or key == prefix:
return key
return None
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -362,6 +404,10 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
@property
def cache_format(self) -> str:
return get_cache_format()
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
@@ -382,6 +428,8 @@ class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
_warned_fallback_to_old_npz = False # to avoid spamming logs about fallback
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
@@ -405,6 +453,10 @@ class LatentsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_format(self) -> str:
return get_cache_format()
@property
def cache_suffix(self):
raise NotImplementedError
@@ -437,7 +489,7 @@ class LatentsCachingStrategy:
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz file
npz_path: path to the npz/safetensors file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents
@@ -452,6 +504,11 @@ class LatentsCachingStrategy:
if self.skip_disk_cache_validity_check:
return True
if npz_path.endswith(".safetensors"):
return self._is_disk_cached_latents_expected_safetensors(
latents_stride, bucket_reso, npz_path, flip_aug, apply_alpha_mask, multi_resolution
)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
@@ -459,11 +516,14 @@ class LatentsCachingStrategy:
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
# In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility)
# In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files).
if "latents" + key_reso_suffix not in npz and "latents" not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz):
return False
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz):
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -471,6 +531,40 @@ class LatentsCachingStrategy:
return True
def _is_disk_cached_latents_expected_safetensors(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
st_path: str,
flip_aug: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
) -> bool:
from library.safetensors_utils import MemoryEfficientSafeOpen
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # (H, W)
reso_tag = f"1x{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "1x"
try:
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_prefix = f"latents_{reso_tag}"
if not any(k.startswith(latents_prefix) for k in keys):
return False
if flip_aug:
flipped_prefix = f"latents_flipped_{reso_tag}"
if not any(k.startswith(flipped_prefix) for k in keys):
return False
if apply_alpha_mask:
mask_prefix = f"alpha_mask_{reso_tag}"
if not any(k.startswith(mask_prefix) for k in keys):
return False
except Exception as e:
logger.error(f"Error loading file: {st_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
@@ -543,7 +637,7 @@ class LatentsCachingStrategy:
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
For single resolution architectures (currently no architecture is single resolution specific). Kept for reference.
Args:
npz_path (str): Path to the npz file.
@@ -566,7 +660,7 @@ class LatentsCachingStrategy:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
npz_path (str): Path to the npz/safetensors file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
@@ -578,15 +672,27 @@ class LatentsCachingStrategy:
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if npz_path.endswith(".safetensors"):
return self._load_latents_from_disk_safetensors(latents_stride, npz_path, bucket_reso)
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
# Fallback to old npz without resolution suffix
if "latents" not in npz:
raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})")
if not self._warned_fallback_to_old_npz:
logger.warning(
f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version."
)
self._warned_fallback_to_old_npz = True
key_reso_suffix = ""
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
@@ -595,6 +701,39 @@ class LatentsCachingStrategy:
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def _load_latents_from_disk_safetensors(
self, latents_stride: Optional[int], st_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
from library.safetensors_utils import MemoryEfficientSafeOpen
if latents_stride is None:
reso_tag = "1x"
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
with MemoryEfficientSafeOpen(st_path) as f:
keys = f.keys()
latents_key = _find_tensor_by_prefix(keys, f"latents_{reso_tag}")
if latents_key is None:
raise ValueError(f"latents with prefix 'latents_{reso_tag}' not found in {st_path}")
latents = f.get_tensor(latents_key).numpy()
original_size_key = _find_tensor_by_prefix(keys, f"original_size_{reso_tag}")
original_size = f.get_tensor(original_size_key).numpy().tolist() if original_size_key else [0, 0]
crop_ltrb_key = _find_tensor_by_prefix(keys, f"crop_ltrb_{reso_tag}")
crop_ltrb = f.get_tensor(crop_ltrb_key).numpy().tolist() if crop_ltrb_key else [0, 0, 0, 0]
flipped_key = _find_tensor_by_prefix(keys, f"latents_flipped_{reso_tag}")
flipped_latents = f.get_tensor(flipped_key).numpy() if flipped_key else None
alpha_mask_key = _find_tensor_by_prefix(keys, f"alpha_mask_{reso_tag}")
alpha_mask = f.get_tensor(alpha_mask_key).numpy() if alpha_mask_key else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
@@ -607,17 +746,23 @@ class LatentsCachingStrategy:
):
"""
Args:
npz_path (str): Path to the npz file.
npz_path (str): Path to the npz/safetensors file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix
key_reso_suffix (str): Key resolution suffix (e.g. "_32x64" for multi-resolution npz)
Returns:
None
"""
if npz_path.endswith(".safetensors"):
self._save_latents_to_disk_safetensors(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor, alpha_mask, key_reso_suffix
)
return
kwargs = {}
if os.path.exists(npz_path):
@@ -626,7 +771,7 @@ class LatentsCachingStrategy:
for key in npz.files:
kwargs[key] = npz[key]
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
# float() is needed because npz doesn't support bfloat16
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
@@ -635,3 +780,59 @@ class LatentsCachingStrategy:
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)
def _save_latents_to_disk_safetensors(
self,
st_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
latents_tensor = latents_tensor.cpu()
latents_size = latents_tensor.shape[-2:] # H, W
reso_tag = f"1x{latents_size[0]}x{latents_size[1]}"
dtype_str = _dtype_to_str(latents_tensor.dtype)
# NaN check and zero replacement
if torch.isnan(latents_tensor).any():
latents_tensor = torch.nan_to_num(latents_tensor, nan=0.0)
tensors: Dict[str, torch.Tensor] = {}
# load existing file and merge (for multi-resolution)
if os.path.exists(st_path):
with MemoryEfficientSafeOpen(st_path) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
tensors[f"latents_{reso_tag}_{dtype_str}"] = latents_tensor
tensors[f"original_size_{reso_tag}_int32"] = torch.tensor(original_size, dtype=torch.int32)
tensors[f"crop_ltrb_{reso_tag}_int32"] = torch.tensor(crop_ltrb, dtype=torch.int32)
if flipped_latents_tensor is not None:
flipped_latents_tensor = flipped_latents_tensor.cpu()
if torch.isnan(flipped_latents_tensor).any():
flipped_latents_tensor = torch.nan_to_num(flipped_latents_tensor, nan=0.0)
tensors[f"latents_flipped_{reso_tag}_{dtype_str}"] = flipped_latents_tensor
if alpha_mask is not None:
alpha_mask_tensor = alpha_mask.cpu() if isinstance(alpha_mask, torch.Tensor) else torch.tensor(alpha_mask)
tensors[f"alpha_mask_{reso_tag}"] = alpha_mask_tensor
metadata = {
"architecture": self._get_architecture_name(),
"width": str(latents_size[1]),
"height": str(latents_size[0]),
"format_version": LATENTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, st_path, metadata=metadata)
def _get_architecture_name(self) -> str:
"""Override in subclasses to return the architecture name for safetensors metadata."""
return "unknown"

View File

@@ -87,6 +87,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_flux_te.safetensors"
def __init__(
self,
@@ -102,7 +103,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.FLUX_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -113,6 +115,26 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "l_pooled"):
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if not _find_tensor_by_prefix(keys, "txt_ids"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
@@ -134,6 +156,18 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
l_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "l_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
txt_ids = f.get_tensor(_find_tensor_by_prefix(keys, "txt_ids")).numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
@@ -161,6 +195,11 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
t5_attn_mask_tokens = tokens_and_masks[2]
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos)
else:
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
@@ -171,7 +210,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
@@ -193,24 +232,63 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(self, l_pooled, t5_out, txt_ids, t5_attn_mask_tokens, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lp = l_pooled[i]
to = t5_out[i]
ti = txt_ids[i]
tensors[f"l_pooled_{_dtype_to_str(lp.dtype)}"] = lp
tensors[f"t5_out_{_dtype_to_str(to.dtype)}"] = to
tensors[f"txt_ids_{_dtype_to_str(ti.dtype)}"] = ti
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "flux",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (l_pooled[i].numpy(), t5_out[i].numpy(), txt_ids[i].numpy(), t5_attn_mask[i].numpy())
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
FLUX_LATENTS_ST_SUFFIX = "_flux.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
return self.FLUX_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "flux"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -81,16 +81,17 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_hi_te.npz"
HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_hi_te.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
suffix = self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return (
os.path.splitext(image_abs_path)[0]
+ HunyuanImageTextEncoderOutputsCachingStrategy.HUNYUAN_IMAGE_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
os.path.splitext(image_abs_path)[0] + suffix
)
def is_disk_cached_outputs_expected(self, npz_path: str):
@@ -102,6 +103,23 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "vlm_embed"):
return False
if "vlm_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "byt5_embed"):
return False
if "byt5_mask" not in keys:
return False
if "ocr_mask" not in keys:
return False
else:
npz = np.load(npz_path)
if "vlm_embed" not in npz:
return False
@@ -120,6 +138,19 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
vlm_embed = f.get_tensor(_find_tensor_by_prefix(keys, "vlm_embed")).numpy()
vlm_mask = f.get_tensor("vlm_mask").numpy()
byt5_embed = f.get_tensor(_find_tensor_by_prefix(keys, "byt5_embed")).numpy()
byt5_mask = f.get_tensor("byt5_mask").numpy()
ocr_mask = f.get_tensor("ocr_mask").numpy()
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
data = np.load(npz_path)
vln_embed = data["vlm_embed"]
vlm_mask = data["vlm_mask"]
@@ -140,6 +171,9 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
tokenize_strategy, models, tokens_and_masks
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos)
else:
if vlm_embed.dtype == torch.bfloat16:
vlm_embed = vlm_embed.float()
if byt5_embed.dtype == torch.bfloat16:
@@ -170,24 +204,69 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
else:
info.text_encoder_outputs = (vlm_embed_i, vlm_mask_i, byt5_embed_i, byt5_mask_i, ocr_mask_i)
def _cache_batch_outputs_safetensors(self, vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
vlm_embed = vlm_embed.cpu()
vlm_mask = vlm_mask.cpu()
byt5_embed = byt5_embed.cpu()
byt5_mask = byt5_mask.cpu()
ocr_mask = ocr_mask.cpu()
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
ve = vlm_embed[i]
be = byt5_embed[i]
tensors[f"vlm_embed_{_dtype_to_str(ve.dtype)}"] = ve
tensors["vlm_mask"] = vlm_mask[i]
tensors[f"byt5_embed_{_dtype_to_str(be.dtype)}"] = be
tensors["byt5_mask"] = byt5_mask[i]
tensors["ocr_mask"] = ocr_mask[i]
metadata = {
"architecture": "hunyuan_image",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (
vlm_embed[i].numpy(),
vlm_mask[i].numpy(),
byt5_embed[i].numpy(),
byt5_mask[i].numpy(),
ocr_mask[i].numpy(),
)
class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX = "_hi.npz"
HUNYUAN_IMAGE_LATENTS_ST_SUFFIX = "_hi.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
return self.HUNYUAN_IMAGE_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ HunyuanImageLatentsCachingStrategy.HUNYUAN_IMAGE_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "hunyuan_image"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(32, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -146,6 +146,7 @@ class LuminaTextEncodingStrategy(TextEncodingStrategy):
class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"
LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_lumina_te.safetensors"
def __init__(
self,
@@ -162,19 +163,10 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return (
os.path.splitext(image_abs_path)[0]
+ LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
)
suffix = self.LUMINA_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
"""
Args:
npz_path (str): Path to the npz file.
Returns:
bool: True if the npz file is expected to be cached.
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -183,6 +175,19 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state"):
return False
if "attention_mask" not in keys:
return False
if "input_ids" not in keys:
return False
else:
npz = np.load(npz_path)
if "hidden_state" not in npz:
return False
@@ -198,11 +203,22 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
"""
Load outputs from a npz file
Load outputs from a npz/safetensors file
Returns:
List[np.ndarray]: hidden_state, input_ids, attention_mask
"""
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state")).numpy()
attention_mask = f.get_tensor("attention_mask").numpy()
input_ids = f.get_tensor("input_ids").numpy()
return [hidden_state, input_ids, attention_mask]
data = np.load(npz_path)
hidden_state = data["hidden_state"]
attention_mask = data["attention_mask"]
@@ -217,16 +233,6 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
text_encoding_strategy: TextEncodingStrategy,
batch: List[train_util.ImageInfo],
) -> None:
"""
Args:
tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
models (List[Any]): Text encoders
text_encoding_strategy (LuminaTextEncodingStrategy):
infos (List): List of ImageInfo
Returns:
None
"""
assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)
@@ -252,18 +258,20 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
)
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state, input_ids, attention_masks, batch)
else:
if hidden_state.dtype != torch.float32:
hidden_state = hidden_state.float()
hidden_state = hidden_state.cpu().numpy()
attention_mask = attention_masks.cpu().numpy() # (B, S)
input_ids = input_ids.cpu().numpy() # (B, S)
attention_mask = attention_masks.cpu().numpy()
input_ids_np = input_ids.cpu().numpy()
for i, info in enumerate(batch):
hidden_state_i = hidden_state[i]
attention_mask_i = attention_mask[i]
input_ids_i = input_ids[i]
input_ids_i = input_ids_np[i]
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
@@ -280,9 +288,45 @@ class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy)
attention_mask_i,
]
def _cache_batch_outputs_safetensors(self, hidden_state, input_ids, attention_masks, batch):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state = hidden_state.cpu()
input_ids = input_ids.cpu()
attention_mask = attention_masks.cpu()
for i, info in enumerate(batch):
if self.cache_to_disk:
assert info.text_encoder_outputs_npz is not None, f"Text encoder cache outputs to disk not found for image {info.image_key}"
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs = hidden_state[i]
tensors[f"hidden_state_{_dtype_to_str(hs.dtype)}"] = hs
tensors["attention_mask"] = attention_mask[i]
tensors["input_ids"] = input_ids[i]
metadata = {
"architecture": "lumina",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = [
hidden_state[i].numpy(),
input_ids[i].numpy(),
attention_mask[i].numpy(),
]
class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"
LUMINA_LATENTS_ST_SUFFIX = "_lumina.safetensors"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
@@ -291,7 +335,7 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
@property
def cache_suffix(self) -> str:
return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
return self.LUMINA_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.LUMINA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(
self, absolute_path: str, image_size: Tuple[int, int]
@@ -299,9 +343,12 @@ class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "lumina"
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],

View File

@@ -2,6 +2,7 @@ import glob
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer
from library import train_util
@@ -137,27 +138,40 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
SD_LATENTS_ST_SUFFIX = "_sd.safetensors"
SDXL_LATENTS_ST_SUFFIX = "_sdxl.safetensors"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
def __init__(
self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
if self.cache_format == "safetensors":
return self.SD_LATENTS_ST_SUFFIX if self.sd else self.SDXL_LATENTS_ST_SUFFIX
else:
return self.SD_LATENTS_NPZ_SUFFIX if self.sd else self.SDXL_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
if self.cache_format != "safetensors":
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def _get_architecture_name(self) -> str:
return "sd" if self.sd else "sdxl"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -165,7 +179,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -255,6 +255,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_sd3_te.safetensors"
def __init__(
self,
@@ -270,7 +271,8 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.SD3_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -281,12 +283,39 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "lg_out"):
return False
if not _find_tensor_by_prefix(keys, "lg_pooled"):
return False
if "clip_l_attn_mask" not in keys or "clip_g_attn_mask" not in keys:
return False
if not _find_tensor_by_prefix(keys, "t5_out"):
return False
if "t5_attn_mask" not in keys:
return False
if "apply_lg_attn_mask" not in keys:
return False
apply_lg = f.get_tensor("apply_lg_attn_mask").item()
if bool(apply_lg) != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in keys:
return False
apply_t5 = f.get_tensor("apply_t5_attn_mask").item()
if bool(apply_t5) != self.apply_t5_attn_mask:
return False
else:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz:
return False
if "apply_lg_attn_mask" not in npz:
return False
@@ -309,6 +338,20 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
lg_out = f.get_tensor(_find_tensor_by_prefix(keys, "lg_out")).numpy()
lg_pooled = f.get_tensor(_find_tensor_by_prefix(keys, "lg_pooled")).numpy()
t5_out = f.get_tensor(_find_tensor_by_prefix(keys, "t5_out")).numpy()
l_attn_mask = f.get_tensor("clip_l_attn_mask").numpy()
g_attn_mask = f.get_tensor("clip_g_attn_mask").numpy()
t5_attn_mask = f.get_tensor("t5_attn_mask").numpy()
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
@@ -339,6 +382,15 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
enable_dropout=False,
)
l_attn_mask_tokens = tokens_and_masks[3]
g_attn_mask_tokens = tokens_and_masks[4]
t5_attn_mask_tokens = tokens_and_masks[5]
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(
lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
)
else:
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
@@ -350,9 +402,9 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
l_attn_mask = l_attn_mask_tokens.cpu().numpy()
g_attn_mask = g_attn_mask_tokens.cpu().numpy()
t5_attn_mask = t5_attn_mask_tokens.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
@@ -380,24 +432,77 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
def _cache_batch_outputs_safetensors(
self, lg_out, t5_out, lg_pooled, l_attn_mask_tokens, g_attn_mask_tokens, t5_attn_mask_tokens, infos
):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
lg_out = lg_out.cpu()
t5_out = t5_out.cpu()
lg_pooled = lg_pooled.cpu()
l_attn_mask = l_attn_mask_tokens.cpu()
g_attn_mask = g_attn_mask_tokens.cpu()
t5_attn_mask = t5_attn_mask_tokens.cpu()
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
tensors[f"lg_out_{_dtype_to_str(lg_out_i.dtype)}"] = lg_out_i
tensors[f"t5_out_{_dtype_to_str(t5_out_i.dtype)}"] = t5_out_i
tensors[f"lg_pooled_{_dtype_to_str(lg_pooled_i.dtype)}"] = lg_pooled_i
tensors["clip_l_attn_mask"] = l_attn_mask[i]
tensors["clip_g_attn_mask"] = g_attn_mask[i]
tensors["t5_attn_mask"] = t5_attn_mask[i]
tensors["apply_lg_attn_mask"] = torch.tensor(self.apply_lg_attn_mask, dtype=torch.bool)
tensors["apply_t5_attn_mask"] = torch.tensor(self.apply_t5_attn_mask, dtype=torch.bool)
metadata = {
"architecture": "sd3",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = (
lg_out[i].numpy(),
t5_out[i].numpy(),
lg_pooled[i].numpy(),
l_attn_mask[i].numpy(),
g_attn_mask[i].numpy(),
t5_attn_mask[i].numpy(),
)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
SD3_LATENTS_ST_SUFFIX = "_sd3.safetensors"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
return self.SD3_LATENTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
+ self.cache_suffix
)
def _get_architecture_name(self) -> str:
return "sd3"
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

View File

@@ -221,6 +221,7 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX = "_te_outputs.safetensors"
def __init__(
self,
@@ -233,7 +234,8 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
suffix = self.SDXL_TEXT_ENCODER_OUTPUTS_ST_SUFFIX if self.cache_format == "safetensors" else self.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
return os.path.splitext(image_abs_path)[0] + suffix
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
@@ -244,6 +246,19 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
try:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
if not _find_tensor_by_prefix(keys, "hidden_state1"):
return False
if not _find_tensor_by_prefix(keys, "hidden_state2"):
return False
if not _find_tensor_by_prefix(keys, "pool2"):
return False
else:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
@@ -254,6 +269,17 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
if npz_path.endswith(".safetensors"):
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.strategy_base import _find_tensor_by_prefix
with MemoryEfficientSafeOpen(npz_path) as f:
keys = f.keys()
hidden_state1 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state1")).numpy()
hidden_state2 = f.get_tensor(_find_tensor_by_prefix(keys, "hidden_state2")).numpy()
pool2 = f.get_tensor(_find_tensor_by_prefix(keys, "pool2")).numpy()
return [hidden_state1, hidden_state2, pool2]
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
@@ -279,6 +305,9 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
if self.cache_format == "safetensors":
self._cache_batch_outputs_safetensors(hidden_state1, hidden_state2, pool2, infos)
else:
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
@@ -304,3 +333,40 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
def _cache_batch_outputs_safetensors(self, hidden_state1, hidden_state2, pool2, infos):
from library.safetensors_utils import mem_eff_save_file, MemoryEfficientSafeOpen
from library.strategy_base import _dtype_to_str, TE_OUTPUTS_CACHE_FORMAT_VERSION
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
for i, info in enumerate(infos):
if self.cache_to_disk:
tensors = {}
# merge existing file if partial
if self.is_partial and os.path.exists(info.text_encoder_outputs_npz):
with MemoryEfficientSafeOpen(info.text_encoder_outputs_npz) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
hs1 = hidden_state1[i]
hs2 = hidden_state2[i]
p2 = pool2[i]
tensors[f"hidden_state1_{_dtype_to_str(hs1.dtype)}"] = hs1
tensors[f"hidden_state2_{_dtype_to_str(hs2.dtype)}"] = hs2
tensors[f"pool2_{_dtype_to_str(p2.dtype)}"] = p2
metadata = {
"architecture": "sdxl",
"caption1": info.caption,
"format_version": TE_OUTPUTS_CACHE_FORMAT_VERSION,
}
mem_eff_save_file(tensors, info.text_encoder_outputs_npz, metadata=metadata)
else:
info.text_encoder_outputs = [
hidden_state1[i].numpy(),
hidden_state2[i].numpy(),
pool2[i].numpy(),
]

View File

@@ -687,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
@@ -727,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset):
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
self.skip_image_resolution = skip_image_resolution
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -1103,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all(
[
not (
subset.caption_dropout_rate > 0 and not cache_supports_dropout
subset.caption_dropout_rate > 0
and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -1915,8 +1919,15 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2034,6 +2045,22 @@ class DreamBoothDataset(BaseDataset):
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
if self.skip_image_resolution is not None:
filtered_img_paths = []
filtered_sizes = []
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
for img_path, size in zip(img_paths, sizes):
if size is None: # no latents cache file, get image size by reading image file (slow)
size = self.get_image_size(img_path)
if size[0] * size[1] <= skip_image_area:
continue
filtered_img_paths.append(img_path)
filtered_sizes.append(size)
if len(filtered_img_paths) < len(img_paths):
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
img_paths = filtered_img_paths
sizes = filtered_sizes
# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
@@ -2059,7 +2086,7 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
captions = [metas[img_path]["caption"] for img_path in img_paths]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
@@ -2200,8 +2227,15 @@ class FineTuningDataset(BaseDataset):
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
@@ -2297,6 +2331,7 @@ class FineTuningDataset(BaseDataset):
tags_list = []
size_set_from_metadata = 0
size_set_from_cache_filename = 0
num_filtered = 0
for image_key in image_keys_sorted_by_length_desc:
img_md = metadata[image_key]
caption = img_md.get("caption")
@@ -2355,6 +2390,16 @@ class FineTuningDataset(BaseDataset):
image_info.image_size = (w, h)
size_set_from_cache_filename += 1
if self.skip_image_resolution is not None:
size = image_info.image_size
if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow)
size = self.get_image_size(abs_path)
image_info.image_size = size
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
if size[0] * size[1] <= skip_image_area:
num_filtered += 1
continue
self.register_image(image_info, subset)
if size_set_from_cache_filename > 0:
@@ -2363,6 +2408,8 @@ class FineTuningDataset(BaseDataset):
)
if size_set_from_metadata > 0:
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
if num_filtered > 0:
logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
self.num_train_images += len(metadata) * subset.num_repeats
# TODO do not record tag freq when no tag
@@ -2387,8 +2434,15 @@ class ControlNetDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
db_subsets = []
for subset in subsets:
@@ -2440,6 +2494,7 @@ class ControlNetDataset(BaseDataset):
validation_split,
validation_seed,
resize_interpolation,
skip_image_resolution,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -2487,9 +2542,10 @@ class ControlNetDataset(BaseDataset):
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
if len(extra_imgs) > 0:
logger.warning(
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
)
self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -4416,7 +4472,10 @@ def verify_training_args(args: argparse.Namespace):
Verify training arguments. Also reflect highvram option to global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
"""
from library.strategy_base import set_cache_format
enable_high_vram(args)
set_cache_format(args.cache_format)
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4582,6 +4641,14 @@ def add_dataset_arguments(
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--cache_format",
type=str,
default="npz",
choices=["npz", "safetensors"],
help="format for latent and text encoder output caches (default: npz). safetensors saves in native dtype (e.g. bf16) for smaller files and faster I/O"
" / latentおよびtext encoder出力キャッシュの保存形式デフォルト: npz。safetensorsはネイティブdtype例: bf16で保存し、ファイルサイズ削減と高速化が可能",
)
parser.add_argument(
"--enable_bucket",
action="store_true",
@@ -4601,6 +4668,13 @@ def add_dataset_arguments(
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--skip_image_resolution",
type=str,
default=None,
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
" / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,
@@ -5414,6 +5488,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
len(args.resolution) == 2
), f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.skip_image_resolution is not None:
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
if len(args.skip_image_resolution) == 1:
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
assert (
len(args.skip_image_resolution) == 2
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'','高さ'で指定してください: {args.skip_image_resolution}"
if args.face_crop_aug_range is not None:
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
assert (

643
networks/loha.py Normal file
View File

@@ -0,0 +1,643 @@
# LoHa (Low-rank Hadamard Product) network module
# Reference: https://arxiv.org/abs/2108.06098
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS
import ast
import os
import logging
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
class HadaWeight(torch.autograd.Function):
"""Efficient Hadamard product forward/backward for LoHa.
Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward
that recomputes intermediates instead of storing them.
"""
@staticmethod
def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
if scale is None:
scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2a @ w2b)
grad_w1a = temp @ w1b.T
grad_w1b = w1a.T @ temp
temp = grad_out * (w1a @ w1b)
grad_w2a = temp @ w2b.T
grad_w2b = w2a.T @ temp
del temp
return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
class HadaWeightTucker(torch.autograd.Function):
"""Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.
Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
Compatible with LyCORIS parameter naming convention.
"""
@staticmethod
def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
if scale is None:
scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
grad_out = grad_out * scale
# Gradients for w1a, w1b, t1 (using rebuild2)
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)
grad_w = rebuild * grad_out
del rebuild
grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
del grad_w, temp
grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
del grad_temp
# Gradients for w2a, w2b, t2 (using rebuild1)
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)
grad_w = rebuild * grad_out
del rebuild
grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
del grad_w, temp
grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
del grad_temp
return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None
class LoHaModule(torch.nn.Module):
"""LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
use_tucker=False,
**kwargs,
):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
is_conv2d = org_module.__class__.__name__ == "Conv2d"
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
kernel_size = org_module.kernel_size
self.is_conv = True
self.stride = org_module.stride
self.padding = org_module.padding
self.dilation = org_module.dilation
self.groups = org_module.groups
self.kernel_size = kernel_size
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
if kernel_size == (1, 1):
self.conv_mode = "1x1"
elif self.tucker:
self.conv_mode = "tucker"
else:
self.conv_mode = "flat"
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
self.tucker = False
self.conv_mode = None
self.kernel_size = None
self.in_dim = in_dim
self.out_dim = out_dim
# Create parameters based on mode
if self.conv_mode == "tucker":
# Tucker decomposition for Conv2d 3x3+
# Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
# LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
torch.nn.init.normal_(self.hada_t1, std=0.1)
torch.nn.init.normal_(self.hada_t2, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w1_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
torch.nn.init.normal_(self.hada_w2_a, std=0.1)
elif self.conv_mode == "flat":
# Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
k_prod = 1
for k in kernel_size:
k_prod *= k
flat_in = in_dim * k_prod
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
else:
# Linear or Conv2d 1x1
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta.
Returns:
- Linear: 2D tensor (out_dim, in_dim)
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
- Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
"""
if self.tucker:
scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
return HadaWeightTucker.apply(
self.hada_t1, self.hada_w1_b, self.hada_w1_a,
self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
)
elif self.conv_mode == "flat":
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
else:
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
diff_weight = self.get_diff_weight()
# rank dropout (applied on output dimension)
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoHaInfModule(LoHaModule):
"""LoHa module for inference. Supports merge_to and get_weight."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference; pass use_tucker from kwargs
use_tucker = kwargs.pop("use_tucker", False)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)
self.org_module_ref = [org_module]
self.enabled = True
self.network: AdditionalNetwork = None
def set_network(self, network):
self.network = network
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float)
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get LoHa weights
w1a = sd["hada_w1_a"].to(torch.float).to(device)
w1b = sd["hada_w1_b"].to(torch.float).to(device)
w2a = sd["hada_w2_a"].to(torch.float).to(device)
w2b = sd["hada_w2_b"].to(torch.float).to(device)
if self.tucker:
# Tucker mode
t1 = sd["hada_t1"].to(torch.float).to(device)
t2 = sd["hada_t2"].to(torch.float).to(device)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
diff_weight = rebuild1 * rebuild2 * self.scale
else:
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
# reshape diff_weight to match original weight shape if needed
if diff_weight.shape != weight.shape:
diff_weight = diff_weight.reshape(weight.shape)
weight = weight.to(device) + self.multiplier * diff_weight
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
if self.tucker:
t1 = self.hada_t1.to(torch.float)
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
t2 = self.hada_t2.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
weight = rebuild1 * rebuild2 * self.scale * multiplier
else:
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
if self.is_conv:
if self.conv_mode == "1x1":
weight = weight.unsqueeze(2).unsqueeze(3)
elif self.conv_mode == "flat":
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoder,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
"""Create a LoHa network. Called by train_network.py via network_module.create_network()."""
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
# exclude patterns
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns from arch config
exclude_patterns.extend(arch_config.default_excludes)
# include patterns
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha for Conv2d 3x3
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
conv_lora_dim = int(conv_lora_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# Tucker decomposition for Conv2d 3x3
use_tucker = kwargs.get("use_tucker", "false")
if use_tucker is not None:
use_tucker = True if str(use_tucker).lower() == "true" else False
# verbose
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if str(verbose).lower() == "true" else False
# regex-specific learning rates / dimensions
network_reg_lrs = kwargs.get("network_reg_lrs", None)
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
network_reg_dims = kwargs.get("network_reg_dims", None)
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoHaModule,
module_kwargs={"use_tucker": use_tucker},
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
# LoRA+ support
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
"""Create a LoHa network from saved weights. Called by train_network.py."""
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# detect dim/alpha from weights
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "hada_w1_b" in key:
dim = value.shape[0]
modules_dim[lora_name] = dim
if "llm_adapter" in lora_name:
train_llm_adapter = True
# detect Tucker mode from weights
use_tucker = any("hada_t1" in key for key in weights_sd.keys())
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
module_class = LoHaInfModule if for_inference else LoHaModule
module_kwargs = {"use_tucker": use_tucker}
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
module_kwargs=module_kwargs,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
def merge_weights_to_tensor(
model_weight: torch.Tensor,
lora_name: str,
lora_sd: Dict[str, torch.Tensor],
lora_weight_keys: set,
multiplier: float,
calc_device: torch.device,
) -> torch.Tensor:
"""Merge LoHa weights directly into a model weight tensor.
Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoHa keys found.
"""
w1a_key = lora_name + ".hada_w1_a"
w1b_key = lora_name + ".hada_w1_b"
w2a_key = lora_name + ".hada_w2_a"
w2b_key = lora_name + ".hada_w2_b"
t1_key = lora_name + ".hada_t1"
t2_key = lora_name + ".hada_t2"
alpha_key = lora_name + ".alpha"
if w1a_key not in lora_weight_keys:
return model_weight
w1a = lora_sd[w1a_key].to(calc_device)
w1b = lora_sd[w1b_key].to(calc_device)
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
has_tucker = t1_key in lora_weight_keys
dim = w1b.shape[0]
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
if isinstance(alpha, torch.Tensor):
alpha = alpha.item()
scale = alpha / dim
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(torch.float16)
w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
if has_tucker:
# Tucker decomposition: rebuild via einsum
t1 = lora_sd[t1_key].to(calc_device)
t2 = lora_sd[t2_key].to(calc_device)
if original_dtype.itemsize == 1:
t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
diff_weight = rebuild1 * rebuild2 * scale
else:
# Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
# Reshape diff_weight to match model_weight shape if needed
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
if diff_weight.shape != model_weight.shape:
diff_weight = diff_weight.reshape(model_weight.shape)
model_weight = model_weight + multiplier * diff_weight
if original_dtype.itemsize == 1:
model_weight = model_weight.to(original_dtype)
# remove consumed keys
consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
if has_tucker:
consumed.extend([t1_key, t2_key])
for key in consumed:
lora_weight_keys.discard(key)
return model_weight

683
networks/lokr.py Normal file
View File

@@ -0,0 +1,683 @@
# LoKr (Low-rank Kronecker Product) network module
# Reference: https://arxiv.org/abs/2309.14859
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS
import ast
import math
import os
import logging
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def factorization(dimension: int, factor: int = -1) -> tuple:
"""Return a tuple of two values whose product equals dimension,
optimized for balanced factors.
In LoKr, the first value is for the weight scale (smaller),
and the second value is for the weight (larger).
Examples:
factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
factor=4: 128 -> (4, 32), 512 -> (4, 128)
"""
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
def make_kron(w1, w2, scale):
"""Compute Kronecker product of w1 and w2, scaled by scale."""
if w1.dim() != w2.dim():
for _ in range(w2.dim() - w1.dim()):
w1 = w1.unsqueeze(-1)
w2 = w2.contiguous()
rebuild = torch.kron(w1, w2)
if scale != 1:
rebuild = rebuild * scale
return rebuild
def rebuild_tucker(t, wa, wb):
"""Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb).
Compatible with LyCORIS convention.
"""
return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)
class LoKrModule(torch.nn.Module):
"""LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
factor=-1,
use_tucker=False,
**kwargs,
):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
is_conv2d = org_module.__class__.__name__ == "Conv2d"
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
kernel_size = org_module.kernel_size
self.is_conv = True
self.stride = org_module.stride
self.padding = org_module.padding
self.dilation = org_module.dilation
self.groups = org_module.groups
self.kernel_size = kernel_size
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
if kernel_size == (1, 1):
self.conv_mode = "1x1"
elif self.tucker:
self.conv_mode = "tucker"
else:
self.conv_mode = "flat"
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
self.tucker = False
self.conv_mode = None
self.kernel_size = None
self.in_dim = in_dim
self.out_dim = out_dim
factor = int(factor)
self.use_w2 = False
# Factorize dimensions
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
# w1 is always a full matrix (the "scale" factor, small)
self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
# w2: depends on mode
if self.conv_mode in ("tucker", "flat"):
# Conv2d 3x3+ modes
k_size = kernel_size
if lora_dim >= max(out_k, in_n) / 2:
# Full matrix mode (includes kernel dimensions)
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
logger.warning(
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
f"and factor={factor}, using full matrix mode for Conv2d."
)
elif self.tucker:
# Tucker mode: separate kernel into t2 tensor
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
# Non-Tucker: flatten kernel into w2_b
k_prod = 1
for k in k_size:
k_prod *= k
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
else:
# Linear or Conv2d 1x1
if lora_dim < max(out_k, in_n) / 2:
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
if lora_dim >= max(out_k, in_n) / 2:
logger.warning(
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
f"and factor={factor}, using full matrix mode."
)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
# if both w1 and w2 are full matrices, use scale = 1
if self.use_w2:
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha))
# Initialization
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
else:
if self.tucker:
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
# Ensures ΔW = kron(w1, 0) = 0 at init
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta.
Returns:
- Linear: 2D tensor (out_dim, in_dim)
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
- Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
"""
w1 = self.lokr_w1
if self.use_w2:
w2 = self.lokr_w2
elif self.tucker:
w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
result = make_kron(w1, w2, self.scale)
# For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D
if self.conv_mode == "flat" and result.dim() == 2:
result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)
return result
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
diff_weight = self.get_diff_weight()
# rank dropout
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoKrInfModule(LoKrModule):
"""LoKr module for inference. Supports merge_to and get_weight."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference; pass factor and use_tucker from kwargs
factor = kwargs.pop("factor", -1)
use_tucker = kwargs.pop("use_tucker", False)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)
self.org_module_ref = [org_module]
self.enabled = True
self.network: AdditionalNetwork = None
def set_network(self, network):
self.network = network
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float)
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get LoKr weights
w1 = sd["lokr_w1"].to(torch.float).to(device)
if "lokr_w2" in sd:
w2 = sd["lokr_w2"].to(torch.float).to(device)
elif "lokr_t2" in sd:
# Tucker mode
t2 = sd["lokr_t2"].to(torch.float).to(device)
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
w2 = rebuild_tucker(t2, w2a, w2b)
else:
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
w2 = w2a @ w2b
# compute ΔW via Kronecker product
diff_weight = make_kron(w1, w2, self.scale)
# reshape diff_weight to match original weight shape if needed
if diff_weight.shape != weight.shape:
diff_weight = diff_weight.reshape(weight.shape)
weight = weight.to(device) + self.multiplier * diff_weight
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
w1 = self.lokr_w1.to(torch.float)
if self.use_w2:
w2 = self.lokr_w2.to(torch.float)
elif self.tucker:
w2 = rebuild_tucker(
self.lokr_t2.to(torch.float),
self.lokr_w2_a.to(torch.float),
self.lokr_w2_b.to(torch.float),
)
else:
w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
weight = make_kron(w1, w2, self.scale) * multiplier
# reshape to match original weight shape if needed
if self.is_conv:
if self.conv_mode == "1x1":
weight = weight.unsqueeze(2).unsqueeze(3)
elif self.conv_mode == "flat" and weight.dim() == 2:
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
# Tucker and full matrix modes: already 4D from kron
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoder,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
"""Create a LoKr network. Called by train_network.py via network_module.create_network()."""
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
# exclude patterns
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns from arch config
exclude_patterns.extend(arch_config.default_excludes)
# include patterns
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha for Conv2d 3x3
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
conv_lora_dim = int(conv_lora_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# Tucker decomposition for Conv2d 3x3
use_tucker = kwargs.get("use_tucker", "false")
if use_tucker is not None:
use_tucker = True if str(use_tucker).lower() == "true" else False
# factor for LoKr
factor = int(kwargs.get("factor", -1))
# verbose
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if str(verbose).lower() == "true" else False
# regex-specific learning rates / dimensions
network_reg_lrs = kwargs.get("network_reg_lrs", None)
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
network_reg_dims = kwargs.get("network_reg_dims", None)
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoKrModule,
module_kwargs={"factor": factor, "use_tucker": use_tucker},
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
# LoRA+ support
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
"""Create a LoKr network from saved weights. Called by train_network.py."""
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# detect dim/alpha from weights
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
use_tucker = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lokr_w2_a" in key:
# low-rank mode: dim detection depends on Tucker vs non-Tucker
if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
# Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
dim = value.shape[0]
else:
# Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
dim = value.shape[1]
modules_dim[lora_name] = dim
elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
# full matrix mode: set dim large enough to trigger full-matrix path
if lora_name not in modules_dim:
modules_dim[lora_name] = max(value.shape[0], value.shape[1])
if "lokr_t2" in key:
use_tucker = True
if "llm_adapter" in lora_name:
train_llm_adapter = True
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# extract factor for LoKr
factor = int(kwargs.get("factor", -1))
module_class = LoKrInfModule if for_inference else LoKrModule
module_kwargs = {"factor": factor, "use_tucker": use_tucker}
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
module_kwargs=module_kwargs,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
def merge_weights_to_tensor(
model_weight: torch.Tensor,
lora_name: str,
lora_sd: Dict[str, torch.Tensor],
lora_weight_keys: set,
multiplier: float,
calc_device: torch.device,
) -> torch.Tensor:
"""Merge LoKr weights directly into a model weight tensor.
Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoKr keys found.
"""
w1_key = lora_name + ".lokr_w1"
w2_key = lora_name + ".lokr_w2"
w2a_key = lora_name + ".lokr_w2_a"
w2b_key = lora_name + ".lokr_w2_b"
t2_key = lora_name + ".lokr_t2"
alpha_key = lora_name + ".alpha"
if w1_key not in lora_weight_keys:
return model_weight
w1 = lora_sd[w1_key].to(calc_device)
# determine mode: full matrix vs Tucker vs low-rank
has_tucker = t2_key in lora_weight_keys
if w2a_key in lora_weight_keys:
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
if has_tucker:
# Tucker: w2a = (rank, out_k), dim = rank
dim = w2a.shape[0]
else:
# Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
dim = w2a.shape[1]
consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
if has_tucker:
consumed_keys.append(t2_key)
elif w2_key in lora_weight_keys:
# full matrix mode
w2a = None
w2b = None
dim = None
consumed_keys = [w1_key, w2_key, alpha_key]
else:
return model_weight
alpha = lora_sd.get(alpha_key, None)
if alpha is not None and isinstance(alpha, torch.Tensor):
alpha = alpha.item()
# compute scale
if w2a is not None:
if alpha is None:
alpha = dim
scale = alpha / dim
else:
# full matrix mode: scale = 1.0
scale = 1.0
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(torch.float16)
w1 = w1.to(torch.float16)
if w2a is not None:
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
# compute w2
if w2a is not None:
if has_tucker:
t2 = lora_sd[t2_key].to(calc_device)
if original_dtype.itemsize == 1:
t2 = t2.to(torch.float16)
w2 = rebuild_tucker(t2, w2a, w2b)
else:
w2 = w2a @ w2b
else:
w2 = lora_sd[w2_key].to(calc_device)
if original_dtype.itemsize == 1:
w2 = w2.to(torch.float16)
# ΔW = kron(w1, w2) * scale
diff_weight = make_kron(w1, w2, scale)
# Reshape diff_weight to match model_weight shape if needed
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
if diff_weight.shape != model_weight.shape:
diff_weight = diff_weight.reshape(model_weight.shape)
model_weight = model_weight + multiplier * diff_weight
if original_dtype.itemsize == 1:
model_weight = model_weight.to(original_dtype)
# remove consumed keys
for key in consumed_keys:
lora_weight_keys.discard(key)
return model_weight

View File

@@ -1,11 +1,11 @@
# LoRA network module for Anima
import ast
import math
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
import logging
@@ -13,6 +13,213 @@ setup_logging()
logger = logging.getLogger(__name__)
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
"""
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if isinstance(self.lora_down, torch.nn.Conv2d):
# Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
mask = mask.unsqueeze(-1).unsqueeze(-1)
else:
# Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
for _ in range(len(lx.size()) - 2):
mask = mask.unsqueeze(1)
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoRAInfModule(LoRAModule):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
# freezeしてマージする
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float) # calc in float
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get up/down weight
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# 復元できるマージのため、このモジュールのweightを返す
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
def default_forward(self, x):
# logger.info(f"default_forward {self.lora_name} {x.size()}")
lx = self.lora_down(x)
lx = self.lora_up(lx)
return self.org_forward(x) + lx * self.multiplier * self.scale
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],

View File

@@ -141,10 +141,13 @@ class LoRAModule(torch.nn.Module):
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
if isinstance(self.lora_down, torch.nn.Conv2d):
# Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
mask = mask.unsqueeze(-1).unsqueeze(-1)
else:
# Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
for _ in range(len(lx.size()) - 2):
mask = mask.unsqueeze(1)
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed

545
networks/network_base.py Normal file
View File

@@ -0,0 +1,545 @@
# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
# Provides architecture detection and a generic AdditionalNetwork class.
import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
@dataclass
class ArchConfig:
unet_target_modules: List[str]
te_target_modules: List[str]
unet_prefix: str
te_prefixes: List[str]
default_excludes: List[str] = field(default_factory=list)
adapter_target_modules: List[str] = field(default_factory=list)
unet_conv_target_modules: List[str] = field(default_factory=list)
def detect_arch_config(unet, text_encoders) -> ArchConfig:
"""Detect architecture from model structure and return ArchConfig."""
from library.sdxl_original_unet import SdxlUNet2DConditionModel
# Check SDXL first
if unet is not None and (
issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
):
return ArchConfig(
unet_target_modules=["Transformer2DModel"],
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
unet_prefix="lora_unet",
te_prefixes=["lora_te1", "lora_te2"],
default_excludes=[],
unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
)
# Check Anima: look for Block class in named_modules
module_class_names = set()
if unet is not None:
for module in unet.modules():
module_class_names.add(type(module).__name__)
if "Block" in module_class_names:
return ArchConfig(
unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
unet_prefix="lora_unet",
te_prefixes=["lora_te"],
default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
adapter_target_modules=["LLMAdapterTransformerBlock"],
)
raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")
def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
"""Parse a string of key-value pairs separated by commas."""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
class AdditionalNetwork(torch.nn.Module):
"""Generic Additional network that supports LoHa, LoKr, and similar module types.
Constructed with a module_class parameter to inject the specific module type.
Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
"""
def __init__(
self,
text_encoders: list,
unet,
arch_config: ArchConfig,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
module_class: Type[torch.nn.Module] = None,
module_kwargs: Optional[Dict] = None,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
train_llm_adapter: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
assert module_class is not None, "module_class must be specified"
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.train_llm_adapter = train_llm_adapter
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.arch_config = arch_config
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if module_kwargs is None:
module_kwargs = {}
if modules_dim is not None:
logger.info(f"create {module_class.__name__} network from weights")
else:
logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expressions
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
prefix: str,
root_module: torch.nn.Module,
target_replace_modules: List[str],
default_dim: Optional[int] = None,
) -> Tuple[List[torch.nn.Module], List[str]]:
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None:
module = root_module
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
# exclude/include filter
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
break
# fallback to default dim
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
elif is_conv2d and self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha_val = self.conv_alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha_val,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
**module_kwargs,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
break
return loras, skipped
# Create modules for text encoders
self.text_encoder_loras: List[torch.nn.Module] = []
skipped_te = []
if text_encoders is not None:
for i, text_encoder in enumerate(text_encoders):
if text_encoder is None:
continue
# Determine prefix for this text encoder
if i < len(arch_config.te_prefixes):
te_prefix = arch_config.te_prefixes[i]
else:
te_prefix = arch_config.te_prefixes[0]
logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
# Create modules for UNet/DiT
target_modules = list(arch_config.unet_target_modules)
if modules_dim is not None or conv_lora_dim is not None:
target_modules.extend(arch_config.unet_conv_target_modules)
if train_llm_adapter and arch_config.adapter_target_modules:
target_modules.extend(arch_config.adapter_target_modules)
self.unet_loras: List[torch.nn.Module]
self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
for name in skipped:
logger.info(f"\t{name}")
# assertion: no duplicate names
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
def is_mergeable(self):
return True
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
te_prefixes = self.arch_config.te_prefixes
unet_prefix = self.arch_config.unet_prefix
for key in weights_sd.keys():
if any(key.startswith(p) for p in te_prefixes):
apply_text_encoder = True
elif key.startswith(unet_prefix):
apply_unet = True
if apply_text_encoder:
logger.info("enable modules for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable modules for UNet/DiT")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
pass # already a list with one element
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
# LoRA+ detection: check for "up" weight parameters
if loraplus_ratio is not None and self._is_plus_param(name):
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and self._is_plus_param(name):
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# Group TE loras by prefix
for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
if len(te_loras) > 0:
te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
all_params.extend(params)
lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def _is_plus_param(self, name: str) -> bool:
"""Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
Override in subclass if needed. Default: check for common 'up' patterns.
"""
return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
def enable_gradient_checkpointing(self):
pass # not supported
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False

View File

@@ -69,6 +69,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
strategy_base.set_cache_format(args.cache_format)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else:

View File

@@ -156,6 +156,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
text_encoder.eval()
# build text encoder outputs caching strategy
strategy_base.set_cache_format(args.cache_format)
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions

View File

@@ -1085,6 +1085,7 @@ class NetworkTrainer:
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"skip_image_resolution": dataset.skip_image_resolution,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
@@ -1191,6 +1192,7 @@ class NetworkTrainer:
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_skip_image_resolution": dataset.skip_image_resolution,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),