From 2217704ce17c8650838627d46e9e8864762070b9 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 22:09:00 +0900 Subject: [PATCH] feat: Support LoKr/LoHa for SDXL and Anima (#2275) * feat: Add LoHa/LoKr network support for SDXL and Anima - networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection - networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions - networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions - library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA. Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately. Co-Authored-By: Claude Opus 4.6 * feat: Enhance LoHa and LoKr modules with Tucker decomposition support - Added Tucker decomposition functionality to LoHa and LoKr modules. - Implemented new methods for weight rebuilding using Tucker decomposition. - Updated initialization and weight handling for Conv2d 3x3+ layers. - Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes. - Enhanced network base to include unet_conv_target_modules for architecture detection.
* fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details * doc: add dtype comment for load_safetensors_with_lora_and_fp8 function * fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py * doc: update model support structure to include Lumina Image 2.0, HunyuanImage-2.1, and Anima-Preview * doc: add documentation for LoHa and LoKr fine-tuning methods * Update networks/network_base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update docs/loha_lokr.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .ai/context/01-overview.md | 3 + .gitignore | 1 + docs/loha_lokr.md | 359 +++++++++++++++++++ library/lora_utils.py | 554 +++++++++++++++--------------- networks/loha.py | 643 ++++++++++++++++++++++++++++++++++ networks/lokr.py | 683 +++++++++++++++++++++++++++++++++++++ networks/lora_anima.py | 209 +++++++++++- networks/network_base.py | 545 +++++++++++++++++++++++++++++ 8 files changed, 2729 insertions(+), 268 deletions(-) create mode 100644 docs/loha_lokr.md create mode 100644 networks/loha.py create mode 100644 networks/lokr.py create mode 100644 networks/network_base.py diff --git a/.ai/context/01-overview.md b/.ai/context/01-overview.md index 41133e98..c37aba19 100644 --- a/.ai/context/01-overview.md +++ b/.ai/context/01-overview.md @@ -21,6 +21,9 @@ Each supported model family has a consistent structure: - **SDXL**: `sdxl_train*.py`, `library/sdxl_*` - **SD3**: `sd3_train*.py`, `library/sd3_*` - **FLUX.1**: `flux_train*.py`, `library/flux_*` +- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*` +- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*` +- **Anima-Preview**: `anima_train*.py`, 
`library/anima_*` ### Key Components diff --git a/.gitignore b/.gitignore index cfdc0268..f5772a7f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ GEMINI.md .claude .gemini MagicMock +references \ No newline at end of file diff --git a/docs/loha_lokr.md b/docs/loha_lokr.md new file mode 100644 index 00000000..6f16ba66 --- /dev/null +++ b/docs/loha_lokr.md @@ -0,0 +1,359 @@ +> 📝 Click on the language section to expand / 言語をクリックして展開 + +# LoHa / LoKr (LyCORIS) + +## Overview / 概要 + +In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project. + +- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098) +- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859) + +The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). + +Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available. + +This feature is experimental. + +
+日本語 + +sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**(Low-rank Hadamard Product)と **LoKr**(Low-rank Kronecker Product)をサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。 + +- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098) +- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859) + +アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。 + +LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflat(カーネル平坦化)モードが利用可能です。 + +この機能は実験的なものです。 + +
+ +## Acknowledgments / 謝辞 + +The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible. + +
+日本語 + +sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。 + +
+ +## Supported architectures / 対応アーキテクチャ + +LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported: + +- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`. +- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules. + +
+日本語 + +LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています: + +- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D`、`Downsample2D`、`Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。 +- **Anima**: DiTの`Block`、`PatchEmbed`、`TimestepEmbedding`、`FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。 + +
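The auto-detection described above can be sketched as a check on which marker module classes appear in the model. This is a hypothetical simplification for illustration only; the actual logic lives in `networks/network_base.py` and may differ:

```python
# Marker classes taken from the target-module lists above.
SDXL_MARKERS = {"Transformer2DModel", "CLIPAttention", "CLIPMLP"}
ANIMA_MARKERS = {"PatchEmbed", "Qwen3Attention", "Qwen3MLP", "FinalLayer"}


def detect_architecture(module_class_names):
    """Return an architecture tag based on which marker classes appear."""
    names = set(module_class_names)
    if names & ANIMA_MARKERS:
        return "anima"
    if names & SDXL_MARKERS:
        return "sdxl"
    return "unknown"


print(detect_architecture(["Transformer2DModel", "ResnetBlock2D"]))  # sdxl
print(detect_architecture(["Block", "PatchEmbed", "FinalLayer"]))    # anima
```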
+ +## Training / 学習 + +To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA. + +
+日本語 + +LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプション(データセット設定、オプティマイザなど)はLoRAと同じです。 + +
+ +### LoHa (SDXL) + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \ + --pretrained_model_name_or_path path/to/sdxl.safetensors \ + --dataset_config path/to/toml \ + --mixed_precision bf16 --fp8_base \ + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \ + --network_module networks.loha --network_dim 32 --network_alpha 16 \ + --max_train_epochs 16 --save_every_n_epochs 1 \ + --output_dir path/to/output --output_name my-loha +``` + +### LoKr (SDXL) + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \ + --pretrained_model_name_or_path path/to/sdxl.safetensors \ + --dataset_config path/to/toml \ + --mixed_precision bf16 --fp8_base \ + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \ + --network_module networks.lokr --network_dim 32 --network_alpha 16 \ + --max_train_epochs 16 --save_every_n_epochs 1 \ + --output_dir path/to/output --output_name my-lokr +``` + +For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options. + +
+日本語 + +Animaの場合は、`sdxl_train_network.py` を `anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。 + +
+ +### Common training options / 共通の学習オプション + +The following `--network_args` options are available for both LoHa and LoKr, same as LoRA: + +| Option | Description | +|---|---| +| `verbose=True` | Display detailed information about the network modules | +| `rank_dropout=0.1` | Apply dropout to the rank dimension during training | +| `module_dropout=0.1` | Randomly skip entire modules during training | +| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) | +| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` | +| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns | +| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns | + +
+日本語 + +以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます: + +| オプション | 説明 | +|---|---| +| `verbose=True` | ネットワークモジュールの詳細情報を表示 | +| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 | +| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ | +| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) | +| `include_patterns=[r'...']` | 除外設定を上書き: `exclude_patterns` に一致していても、これらの正規表現パターンに一致するモジュールは対象に含める | +| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 | +| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 | + +
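The interaction between the two pattern options (excludes apply first, includes override them) can be illustrated with a small sketch. `select_modules` is a hypothetical helper, not the actual implementation:

```python
import re


def select_modules(names, exclude_patterns, include_patterns):
    """Drop modules matching exclude_patterns, unless they also match
    include_patterns (includes override excludes)."""
    excludes = [re.compile(p) for p in exclude_patterns]
    includes = [re.compile(p) for p in include_patterns]
    selected = []
    for name in names:
        excluded = any(p.search(name) for p in excludes)
        overridden = any(p.search(name) for p in includes)
        if not excluded or overridden:
            selected.append(name)
    return selected


names = ["blocks.0.attn.to_q", "blocks.0.norm", "final_layer.linear"]
# "norm" is excluded; "final_layer" matches both, so the include wins.
print(select_modules(names, [r"norm", r"final_layer"], [r"final_layer"]))
# ['blocks.0.attn.to_q', 'final_layer.linear']
```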
+ +### Conv2d support / Conv2dサポート + +By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" +``` + +For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True" +``` + +- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode). +- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient. + +
+日本語 + +デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層(SDXLのResNetブロックなど)も学習するには、`conv_dim`と`conv_alpha`オプションを使用します: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" +``` + +Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True" +``` + +- `use_tucker`なし: カーネル次元が入力次元に平坦化されます(flatモード)。 +- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。 + +
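Flat mode can be sketched at the shape level: the k×k kernel dimensions fold into the input dimension, so the Conv2d update is parameterized like a Linear layer over the flattened kernel. The shapes below are illustrative only, not the actual `networks/loha.py` code:

```python
import numpy as np

# Toy Conv2d 3x3 layer: 8 output channels, 4 input channels, rank 2.
out_ch, in_ch, k, rank = 8, 4, 3, 2

rng = np.random.default_rng(0)
down = rng.standard_normal((rank, in_ch * k * k))  # flattened patch -> rank
up = rng.standard_normal((out_ch, rank))           # rank -> output channels

delta_flat = up @ down                             # (out, in*k*k), like Linear
delta_w = delta_flat.reshape(out_ch, in_ch, k, k)  # back to conv kernel shape
print(delta_w.shape)  # (8, 4, 3, 3)
```

With `use_tucker=True`, a separate core tensor handles the k×k dimensions instead, as shown by the einsum in the "How LoHa and LoKr work" section below.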
+ +### LoKr-specific option: `factor` / LoKr固有のオプション: `factor` + +LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split: + +- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32). +- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128). + +```bash +--network_args "factor=4" +``` + +When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case. + +
+日本語 + +LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します: + +- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。 +- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。 + +```bash +--network_args "factor=4" +``` + +`network_dim`(ランク)が因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。 + +
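The factorization behavior described above can be sketched as follows. This is a simplified stand-in for the actual logic in `networks/lokr.py`, which may differ in detail:

```python
import math


def factorize(dim, factor=-1):
    """Split dim into (a, b) with a * b == dim.

    factor == -1: pick the most balanced pair (a closest to sqrt(dim)).
    factor > 0:   use the largest divisor of dim that is <= factor.
    """
    if factor > 0:
        a = max(d for d in range(1, factor + 1) if dim % d == 0)
        return a, dim // a
    a = 1
    for d in range(1, math.isqrt(dim) + 1):
        if dim % d == 0:
            a = d
    return a, dim // a


print(factorize(512))     # (16, 32) -- balanced split
print(factorize(512, 4))  # (4, 128) -- forced factor
```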
+ +### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter` + +For Anima, you can additionally train the LLM adapter modules by specifying: + +```bash +--network_args "train_llm_adapter=True" +``` + +This includes `LLMAdapterTransformerBlock` modules as training targets. + +
+日本語 + +Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます: + +```bash +--network_args "train_llm_adapter=True" +``` + +これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。 + +
+ +### LoRA+ + +LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group. + +```bash +--network_args "loraplus_lr_ratio=4" +``` + +This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them. + +
+日本語 + +LoRA+(`--network_args` の `loraplus_lr_ratio` 等)はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列(`hada_w2_a`)が「plus」(より高い学習率)パラメータグループとして扱われます。LoKrではスケール係数(`lokr_w1`)が「plus」パラメータグループとして扱われます。 + +```bash +--network_args "loraplus_lr_ratio=4" +``` + +この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。 + +
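The grouping described above amounts to splitting parameters into a base group and a "plus" group trained at `base_lr * loraplus_lr_ratio`. A minimal sketch with a hypothetical helper (the actual wiring in the network classes may differ):

```python
def build_param_groups(named_params, base_lr, loraplus_lr_ratio,
                       plus_markers=("hada_w2_a", "lokr_w1")):
    """Split parameters into base and 'plus' optimizer groups."""
    base, plus = [], []
    for name, param in named_params:
        if any(marker in name for marker in plus_markers):
            plus.append(param)
        else:
            base.append(param)
    return [
        {"params": base, "lr": base_lr},
        {"params": plus, "lr": base_lr * loraplus_lr_ratio},
    ]


# Stand-in parameter objects; real code would pass module.named_parameters().
groups = build_param_groups(
    [("m.hada_w1_a", "p1"), ("m.hada_w2_a", "p2")],
    base_lr=2e-4, loraplus_lr_ratio=4,
)
print([g["lr"] for g in groups])  # second group trains 4x faster
```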
+ +## How LoHa and LoKr work / LoHaとLoKrの仕組み + +### LoHa + +LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices: + +``` +ΔW = (W1a × W1b) ⊙ (W2a × W2b) +``` + +where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product. + +For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`. + +### LoKr + +LoKr represents the weight update using a Kronecker product: + +``` +ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode) +``` + +The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank. + +
+日本語 + +### LoHa + +LoHaは重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します: + +``` +ΔW = (W1a × W1b) ⊙ (W2a × W2b) +``` + +ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。 + +Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。 + +### LoKr + +LoKrはKronecker積を使って重みの更新を表現します: + +``` +ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b) +``` + +元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます)。W1は常にフル行列(小さい)で、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。 + +
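The two update rules above can be sketched with NumPy. Toy shapes only; the `alpha/dim` scaling and Conv2d handling are omitted:

```python
import numpy as np

rng = np.random.default_rng(0)
out_dim, in_dim, rank = 8, 8, 2

# LoHa: Hadamard (element-wise) product of two rank-`rank` factorizations.
w1a, w1b = rng.standard_normal((out_dim, rank)), rng.standard_normal((rank, in_dim))
w2a, w2b = rng.standard_normal((out_dim, rank)), rng.standard_normal((rank, in_dim))
delta_loha = (w1a @ w1b) * (w2a @ w2b)

# LoKr: Kronecker product of a small full matrix and a low-rank factor.
# 8 = 2 * 4 on both axes, so (2, 2) kron (4, 4) reconstructs (8, 8).
w1 = rng.standard_normal((2, 2))
w2 = rng.standard_normal((4, rank)) @ rng.standard_normal((rank, 4))
delta_lokr = np.kron(w1, w2)

print(delta_loha.shape, delta_lokr.shape)  # (8, 8) (8, 8)
```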
+ +## Inference / 推論 + +Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA. + +
+日本語 + +学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。 + +
+ +### SDXL + +For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA: + +```bash +python gen_img.py --ckpt path/to/sdxl.safetensors \ + --network_module networks.loha --network_weights path/to/loha.safetensors \ + --prompt "your prompt" ... +``` + +Replace `networks.loha` with `networks.lokr` for LoKr weights. + +
+日本語 + +SDXLでは、LoRAと同様に `gen_img.py` で `--network_module` と `--network_weights` を指定します: + +```bash +python gen_img.py --ckpt path/to/sdxl.safetensors \ + --network_module networks.loha --network_weights path/to/loha.safetensors \ + --prompt "your prompt" ... +``` + +LoKrの重みを使用する場合は `networks.loha` を `networks.lokr` に置き換えてください。 + +
+ +### Anima + +For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged: + +```bash +python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \ + --lora_weight path/to/loha_or_lokr.safetensors ... +``` + +
+日本語 + +Animaでは、`anima_minimal_inference.py` に `--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます: + +```bash +python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \ + --lora_weight path/to/loha_or_lokr.safetensors ... +``` + +
diff --git a/library/lora_utils.py b/library/lora_utils.py index 90e3c389..dadad898 100644 --- a/library/lora_utils.py +++ b/library/lora_utils.py @@ -1,267 +1,287 @@ -import os -import re -from typing import Dict, List, Optional, Union -import torch -from tqdm import tqdm -from library.device_utils import synchronize_device -from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization -from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames -from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -def filter_lora_state_dict( - weights_sd: Dict[str, torch.Tensor], - include_pattern: Optional[str] = None, - exclude_pattern: Optional[str] = None, -) -> Dict[str, torch.Tensor]: - # apply include/exclude patterns - original_key_count = len(weights_sd.keys()) - if include_pattern is not None: - regex_include = re.compile(include_pattern) - weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} - logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") - - if exclude_pattern is not None: - original_key_count_ex = len(weights_sd.keys()) - regex_exclude = re.compile(exclude_pattern) - weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} - logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") - - if len(weights_sd) != original_key_count: - remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) - remaining_keys.sort() - logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") - if len(weights_sd) == 0: - logger.warning("No keys left after filtering.") - - return weights_sd - - -def load_safetensors_with_lora_and_fp8( - model_files: Union[str, List[str]], - lora_weights_list: Optional[List[Dict[str, 
torch.Tensor]]], - lora_multipliers: Optional[List[float]], - fp8_optimization: bool, - calc_device: torch.device, - move_to_device: bool = False, - dit_weight_dtype: Optional[torch.dtype] = None, - target_keys: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - disable_numpy_memmap: bool = False, - weight_transform_hooks: Optional[WeightTransformHooks] = None, -) -> dict[str, torch.Tensor]: - """ - Merge LoRA weights into the state dict of a model with fp8 optimization if needed. - - Args: - model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. - lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load. - lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. - fp8_optimization (bool): Whether to apply FP8 optimization. - calc_device (torch.device): Device to calculate on. - move_to_device (bool): Whether to move tensors to the calculation device after loading. - target_keys (Optional[List[str]]): Keys to target for optimization. - exclude_keys (Optional[List[str]]): Keys to exclude from optimization. - disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors. - weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading. 
- """ - - # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix - if isinstance(model_files, str): - model_files = [model_files] - - extended_model_files = [] - for model_file in model_files: - split_filenames = get_split_weight_filenames(model_file) - if split_filenames is not None: - extended_model_files.extend(split_filenames) - else: - extended_model_files.append(model_file) - model_files = extended_model_files - logger.info(f"Loading model files: {model_files}") - - # load LoRA weights - weight_hook = None - if lora_weights_list is None or len(lora_weights_list) == 0: - lora_weights_list = [] - lora_multipliers = [] - list_of_lora_weight_keys = [] - else: - list_of_lora_weight_keys = [] - for lora_sd in lora_weights_list: - lora_weight_keys = set(lora_sd.keys()) - list_of_lora_weight_keys.append(lora_weight_keys) - - if lora_multipliers is None: - lora_multipliers = [1.0] * len(lora_weights_list) - while len(lora_multipliers) < len(lora_weights_list): - lora_multipliers.append(1.0) - if len(lora_multipliers) > len(lora_weights_list): - lora_multipliers = lora_multipliers[: len(lora_weights_list)] - - # Merge LoRA weights into the state dict - logger.info(f"Merging LoRA weights into state dict. 
multipliers: {lora_multipliers}") - - # make hook for LoRA merging - def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False): - nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device - - if not model_weight_key.endswith(".weight"): - return model_weight - - original_device = model_weight.device - if original_device != calc_device: - model_weight = model_weight.to(calc_device) # to make calculation faster - - for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): - # check if this weight has LoRA weights - lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" - found = False - for prefix in ["lora_unet_", ""]: - lora_name = prefix + lora_name_without_prefix.replace(".", "_") - down_key = lora_name + ".lora_down.weight" - up_key = lora_name + ".lora_up.weight" - alpha_key = lora_name + ".alpha" - if down_key in lora_weight_keys and up_key in lora_weight_keys: - found = True - break - if not found: - continue # no LoRA weights for this model weight - - # get LoRA weights - down_weight = lora_sd[down_key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - down_weight = down_weight.to(calc_device) - up_weight = up_weight.to(calc_device) - - original_dtype = model_weight.dtype - if original_dtype.itemsize == 1: # fp8 - # temporarily convert to float16 for calculation - model_weight = model_weight.to(torch.float16) - down_weight = down_weight.to(torch.float16) - up_weight = up_weight.to(torch.float16) - - # W <- W + U * D - if len(model_weight.size()) == 2: - # linear - if len(up_weight.size()) == 4: # use linear projection mismatch - up_weight = up_weight.squeeze(3).squeeze(2) - down_weight = down_weight.squeeze(3).squeeze(2) - model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == 
(1, 1): - # conv2d 1x1 - model_weight = ( - model_weight - + multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - model_weight = model_weight + multiplier * conved * scale - - if original_dtype.itemsize == 1: # fp8 - model_weight = model_weight.to(original_dtype) # convert back to original dtype - - # remove LoRA keys from set - lora_weight_keys.remove(down_key) - lora_weight_keys.remove(up_key) - if alpha_key in lora_weight_keys: - lora_weight_keys.remove(alpha_key) - - if not keep_on_calc_device and original_device != calc_device: - model_weight = model_weight.to(original_device) # move back to original device - return model_weight - - weight_hook = weight_hook_func - - state_dict = load_safetensors_with_fp8_optimization_and_hook( - model_files, - fp8_optimization, - calc_device, - move_to_device, - dit_weight_dtype, - target_keys, - exclude_keys, - weight_hook=weight_hook, - disable_numpy_memmap=disable_numpy_memmap, - weight_transform_hooks=weight_transform_hooks, - ) - - for lora_weight_keys in list_of_lora_weight_keys: - # check if all LoRA keys are used - if len(lora_weight_keys) > 0: - # if there are still LoRA keys left, it means they are not used in the model - # this is a warning, not an error - logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}") - - return state_dict - - -def load_safetensors_with_fp8_optimization_and_hook( - model_files: list[str], - fp8_optimization: bool, - calc_device: torch.device, - move_to_device: bool = False, - dit_weight_dtype: Optional[torch.dtype] = None, - target_keys: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - weight_hook: callable = None, - disable_numpy_memmap: bool = False, - 
weight_transform_hooks: Optional[WeightTransformHooks] = None, -) -> dict[str, torch.Tensor]: - """ - Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. - """ - if fp8_optimization: - logger.info( - f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" - ) - # dit_weight_dtype is not used because we use fp8 optimization - state_dict = load_safetensors_with_fp8_optimization( - model_files, - calc_device, - target_keys, - exclude_keys, - move_to_device=move_to_device, - weight_hook=weight_hook, - disable_numpy_memmap=disable_numpy_memmap, - weight_transform_hooks=weight_transform_hooks, - ) - else: - logger.info( - f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" - ) - state_dict = {} - for model_file in model_files: - with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f: - f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f - for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): - if weight_hook is None and move_to_device: - value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) - else: - value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer - if weight_hook is not None: - value = weight_hook(key, value, keep_on_calc_device=move_to_device) - if move_to_device: - value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) - elif dit_weight_dtype is not None: - value = value.to(dit_weight_dtype) - - state_dict[key] = value - if move_to_device: - synchronize_device(calc_device) - - return state_dict +import os +import re +from typing import Dict, List, Optional, Union +import torch +from tqdm import tqdm +from library.device_utils import 
synchronize_device +from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization +from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames +from networks.loha import merge_weights_to_tensor as loha_merge +from networks.lokr import merge_weights_to_tensor as lokr_merge + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def filter_lora_state_dict( + weights_sd: Dict[str, torch.Tensor], + include_pattern: Optional[str] = None, + exclude_pattern: Optional[str] = None, +) -> Dict[str, torch.Tensor]: + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_pattern is not None: + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + + if exclude_pattern is not None: + original_key_count_ex = len(weights_sd.keys()) + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") + + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + return weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: Union[str, List[str]], + lora_weights_list: Optional[List[Dict[str, torch.Tensor]]], + lora_multipliers: Optional[List[float]], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: 
Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + disable_numpy_memmap: bool = False, + weight_transform_hooks: Optional[WeightTransformHooks] = None, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. + + Args: + model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. + lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load. + lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. + fp8_optimization (bool): Whether to apply FP8 optimization. + calc_device (torch.device): Device to calculate on. + move_to_device (bool): Whether to move tensors to the calculation device after loading. + dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization. + target_keys (Optional[List[str]]): Keys to target for optimization. + exclude_keys (Optional[List[str]]): Keys to exclude from optimization. + disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors. + weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading. 
+ """ + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + if isinstance(model_files, str): + model_files = [model_files] + + extended_model_files = [] + for model_file in model_files: + split_filenames = get_split_weight_filenames(model_file) + if split_filenames is not None: + extended_model_files.extend(split_filenames) + else: + extended_model_files.append(model_file) + model_files = extended_model_files + logger.info(f"Loading model files: {model_files}") + + # load LoRA weights + weight_hook = None + if lora_weights_list is None or len(lora_weights_list) == 0: + lora_weights_list = [] + lora_multipliers = [] + list_of_lora_weight_keys = [] + else: + list_of_lora_weight_keys = [] + for lora_sd in lora_weights_list: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + if lora_multipliers is None: + lora_multipliers = [1.0] * len(lora_weights_list) + while len(lora_multipliers) < len(lora_weights_list): + lora_multipliers.append(1.0) + if len(lora_multipliers) > len(lora_weights_list): + lora_multipliers = lora_multipliers[: len(lora_weights_list)] + + # Merge LoRA weights into the state dict + logger.info(f"Merging LoRA weights into state dict. 
multipliers: {lora_multipliers}") + + # make hook for LoRA merging + def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False): + nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != calc_device: + model_weight = model_weight.to(calc_device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): + # check if this weight has LoRA weights + lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + found = False + for prefix in ["lora_unet_", ""]: + lora_name = prefix + lora_name_without_prefix.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key in lora_weight_keys and up_key in lora_weight_keys: + found = True + break + + if found: + # Standard LoRA merge + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(calc_device) + up_weight = up_weight.to(calc_device) + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + # temporarily convert to float16 for calculation + model_weight = model_weight.to(torch.float16) + down_weight = down_weight.to(torch.float16) + up_weight = up_weight.to(torch.float16) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + 
model_weight = ( + model_weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + model_weight = model_weight + multiplier * conved * scale + + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(original_dtype) # convert back to original dtype + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + continue + + # Check for LoHa/LoKr weights with same prefix search + for prefix in ["lora_unet_", ""]: + lora_name = prefix + lora_name_without_prefix.replace(".", "_") + hada_key = lora_name + ".hada_w1_a" + lokr_key = lora_name + ".lokr_w1" + + if hada_key in lora_weight_keys: + # LoHa merge + model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device) + break + elif lokr_key in lora_weight_keys: + # LoKr merge + model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device) + break + + if not keep_on_calc_device and original_device != calc_device: + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + weight_hook = weight_hook_func + + state_dict = load_safetensors_with_fp8_optimization_and_hook( + model_files, + fp8_optimization, + calc_device, + move_to_device, + dit_weight_dtype, + target_keys, + exclude_keys, + weight_hook=weight_hook, + disable_numpy_memmap=disable_numpy_memmap, + weight_transform_hooks=weight_transform_hooks, + ) + + for lora_weight_keys in list_of_lora_weight_keys: + # check if all LoRA keys are used + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not 
used in the model + # this is a warning, not an error + logger.warning(f"not all LoRA keys are used: {', '.join(lora_weight_keys)}") + + return state_dict + + +def load_safetensors_with_fp8_optimization_and_hook( + model_files: list[str], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + weight_hook: callable = None, + disable_numpy_memmap: bool = False, + weight_transform_hooks: Optional[WeightTransformHooks] = None, +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. + """ + if fp8_optimization: + logger.info( + f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + # dit_weight_dtype is not used because we use fp8 optimization + state_dict = load_safetensors_with_fp8_optimization( + model_files, + calc_device, + target_keys, + exclude_keys, + move_to_device=move_to_device, + weight_hook=weight_hook, + disable_numpy_memmap=disable_numpy_memmap, + weight_transform_hooks=weight_transform_hooks, + ) + else: + logger.info( + f"Loading state dict without FP8 optimization.
Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f: + f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f + for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): + if weight_hook is None and move_to_device: + value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) + else: + value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer + if weight_hook is not None: + value = weight_hook(key, value, keep_on_calc_device=move_to_device) + if move_to_device: + value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) + elif dit_weight_dtype is not None: + value = value.to(dit_weight_dtype) + + state_dict[key] = value + if move_to_device: + synchronize_device(calc_device) + + return state_dict diff --git a/networks/loha.py b/networks/loha.py new file mode 100644 index 00000000..8734f9c5 --- /dev/null +++ b/networks/loha.py @@ -0,0 +1,643 @@ +# LoHa (Low-rank Hadamard Product) network module +# Reference: https://arxiv.org/abs/2108.06098 +# +# Based on the LyCORIS project by KohakuBlueleaf +# https://github.com/KohakuBlueleaf/LyCORIS + +import ast +import os +import logging +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs +from library.utils import setup_logging + +setup_logging() +logger = logging.getLogger(__name__) + + +class HadaWeight(torch.autograd.Function): + """Efficient Hadamard product forward/backward for LoHa. + + Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward + that recomputes intermediates instead of storing them. 
+ """ + + @staticmethod + def forward(ctx, w1a, w1b, w2a, w2b, scale=None): + if scale is None: + scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype) + ctx.save_for_backward(w1a, w1b, w2a, w2b, scale) + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2a @ w2b) + grad_w1a = temp @ w1b.T + grad_w1b = w1a.T @ temp + + temp = grad_out * (w1a @ w1b) + grad_w2a = temp @ w2b.T + grad_w2b = w2a.T @ temp + + del temp + return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None + + +class HadaWeightTucker(torch.autograd.Function): + """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+. + + Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale + where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa). + Compatible with LyCORIS parameter naming convention. + """ + + @staticmethod + def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None): + if scale is None: + scale = torch.tensor(1, device=t1.device, dtype=t1.dtype) + ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + # Gradients for w1a, w1b, t1 (using rebuild2) + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T) + del grad_w, temp + + grad_w1b = torch.einsum("i r ..., i j ... 
-> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T) + del grad_temp + + # Gradients for w2a, w2b, t2 (using rebuild1) + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T) + del grad_w, temp + + grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T) + del grad_temp + + return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None + + +class LoHaModule(torch.nn.Module): + """LoHa module for training. Replaces forward method of the original Linear/Conv2d.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + use_tucker=False, + **kwargs, + ): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + is_conv2d = org_module.__class__.__name__ == "Conv2d" + if is_conv2d: + in_dim = org_module.in_channels + out_dim = org_module.out_channels + kernel_size = org_module.kernel_size + self.is_conv = True + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None + + self.in_dim = in_dim + self.out_dim = out_dim + + # Create parameters based on mode + if self.conv_mode == 
"tucker": + # Tucker decomposition for Conv2d 3x3+ + # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim) + self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + # LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1) + torch.nn.init.normal_(self.hada_t1, std=0.1) + torch.nn.init.normal_(self.hada_t2, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w1_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + torch.nn.init.normal_(self.hada_w2_a, std=0.1) + elif self.conv_mode == "flat": + # Non-Tucker Conv2d 3x3+: flatten kernel into in_dim + k_prod = 1 + for k in kernel_size: + k_prod *= k + flat_in = in_dim * k_prod + + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + else: + # Linear or Conv2d 1x1 + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + + if 
type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def get_diff_weight(self): + """Return materialized weight delta. + + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) + """ + if self.tucker: + scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device) + return HadaWeightTucker.apply( + self.hada_t1, self.hada_w1_b, self.hada_w1_a, + self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale + ) + elif self.conv_mode == "flat": + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size) + else: + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + diff_weight = self.get_diff_weight() + + # rank dropout (applied on output dimension) + if self.rank_dropout is not None and self.training: + drop = (torch.rand(diff_weight.size(0), 
device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) + diff_weight = diff_weight * drop + scale = 1.0 / (1.0 - self.rank_dropout) + else: + scale = 1.0 + + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoHaInfModule(LoHaModule): + """LoHa module for inference. 
Supports merge_to and get_weight.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference; pass use_tucker from kwargs + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker) + + self.org_module_ref = [org_module] + self.enabled = True + self.network: AdditionalNetwork = None + + def set_network(self, network): + self.network = network + + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get LoHa weights + w1a = sd["hada_w1_a"].to(torch.float).to(device) + w1b = sd["hada_w1_b"].to(torch.float).to(device) + w2a = sd["hada_w2_a"].to(torch.float).to(device) + w2b = sd["hada_w2_b"].to(torch.float).to(device) + + if self.tucker: + # Tucker mode + t1 = sd["hada_t1"].to(torch.float).to(device) + t2 = sd["hada_t2"].to(torch.float).to(device) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * self.scale + else: + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) + + weight = weight.to(device) + self.multiplier * diff_weight + + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.tucker: + t1 = self.hada_t1.to(torch.float) + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) 
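As an aside, the core LoHa algebra used by both the training and inference paths can be sanity-checked without torch. The sketch below (plain Python lists, illustrative sizes, not part of this module) mirrors ΔW = ((w1a @ w1b) ⊙ (w2a @ w2b)) · scale and confirms the zero-init property noted in `__init__`: in the Linear/Conv2d-1x1 branch `hada_w2_a` starts at zero, so ΔW is zero at the start of training.

```python
# Pure-Python sketch of the LoHa delta weight (illustrative sizes, no torch):
#   delta_W = ((w1a @ w1b) * (w2a @ w2b)) * scale
# where "*" is the elementwise (Hadamard) product and scale = alpha / lora_dim.

def matmul(a, b):
    # naive matrix product over nested lists: (m x k) @ (k x n) -> (m x n)
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def hadamard(a, b, scale=1.0):
    # elementwise product of same-shape matrices, scaled
    return [[a[i][j] * b[i][j] * scale for j in range(len(a[0]))] for i in range(len(a))]

out_dim, in_dim, lora_dim = 3, 4, 2
alpha = 1.0
scale = alpha / lora_dim  # 0.5, matching self.scale in LoHaModule

w1a = [[1.0] * lora_dim for _ in range(out_dim)]  # (out_dim, lora_dim)
w1b = [[1.0] * in_dim for _ in range(lora_dim)]   # (lora_dim, in_dim)
w2a = [[0.0] * lora_dim for _ in range(out_dim)]  # zero-initialized, like hada_w2_a
w2b = [[1.0] * in_dim for _ in range(lora_dim)]

delta = hadamard(matmul(w1a, w1b), matmul(w2a, w2b), scale)

# Because hada_w2_a starts at zero, the initial delta weight is exactly zero,
# so the wrapped Linear behaves identically at the start of training.
assert all(v == 0.0 for row in delta for v in row)
assert len(delta) == out_dim and len(delta[0]) == in_dim
```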
+ t2 = self.hada_t2.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + weight = rebuild1 * rebuild2 * self.scale * multiplier + else: + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier + + if self.is_conv: + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat": + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) + + return weight + + def default_forward(self, x): + diff_weight = self.get_diff_weight() + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier + else: + return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae, + text_encoder, + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + """Create a LoHa network. 
Called by train_network.py via network_module.create_network().""" + if network_dim is None: + network_dim = 4 + if network_alpha is None: + network_alpha = 1.0 + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + # train LLM adapter + train_llm_adapter = kwargs.get("train_llm_adapter", "false") + if train_llm_adapter is not None: + train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False + + # exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns from arch config + exclude_patterns.extend(arch_config.default_excludes) + + # include patterns + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # conv dim/alpha for Conv2d 3x3 + conv_lora_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_lora_dim is not None: + conv_lora_dim = int(conv_lora_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + + # verbose + verbose = 
kwargs.get("verbose", "false") + if verbose is not None: + verbose = True if str(verbose).lower() == "true" else False + + # regex-specific learning rates / dimensions + network_reg_lrs = kwargs.get("network_reg_lrs", None) + reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None + + network_reg_dims = kwargs.get("network_reg_dims", None) + reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + module_class=LoHaModule, + module_kwargs={"use_tucker": use_tucker}, + conv_lora_dim=conv_lora_dim, + conv_alpha=conv_alpha, + train_llm_adapter=train_llm_adapter, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + reg_dims=reg_dims, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + # LoRA+ support + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + """Create a LoHa network 
from saved weights. Called by train_network.py.""" + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # detect dim/alpha from weights + modules_dim = {} + modules_alpha = {} + train_llm_adapter = False + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "hada_w1_b" in key: + dim = value.shape[0] + modules_dim[lora_name] = dim + + if "llm_adapter" in lora_name: + train_llm_adapter = True + + # detect Tucker mode from weights + use_tucker = any("hada_t1" in key for key in weights_sd.keys()) + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + module_class = LoHaInfModule if for_inference else LoHaModule + module_kwargs = {"use_tucker": use_tucker} + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + module_kwargs=module_kwargs, + train_llm_adapter=train_llm_adapter, + ) + return network, weights_sd + + +def merge_weights_to_tensor( + model_weight: torch.Tensor, + lora_name: str, + lora_sd: Dict[str, torch.Tensor], + lora_weight_keys: set, + multiplier: float, + calc_device: torch.device, +) -> torch.Tensor: + """Merge LoHa weights directly into a model weight tensor. + + Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3. + No Module/Network creation needed. Consumed keys are removed from lora_weight_keys. + Returns model_weight unchanged if no matching LoHa keys found. 
+ """ + w1a_key = lora_name + ".hada_w1_a" + w1b_key = lora_name + ".hada_w1_b" + w2a_key = lora_name + ".hada_w2_a" + w2b_key = lora_name + ".hada_w2_b" + t1_key = lora_name + ".hada_t1" + t2_key = lora_name + ".hada_t2" + alpha_key = lora_name + ".alpha" + + if w1a_key not in lora_weight_keys: + return model_weight + + w1a = lora_sd[w1a_key].to(calc_device) + w1b = lora_sd[w1b_key].to(calc_device) + w2a = lora_sd[w2a_key].to(calc_device) + w2b = lora_sd[w2b_key].to(calc_device) + + has_tucker = t1_key in lora_weight_keys + + dim = w1b.shape[0] + alpha = lora_sd.get(alpha_key, torch.tensor(dim)) + if isinstance(alpha, torch.Tensor): + alpha = alpha.item() + scale = alpha / dim + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(torch.float16) + w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16) + w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16) + + if has_tucker: + # Tucker decomposition: rebuild via einsum + t1 = lora_sd[t1_key].to(calc_device) + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t1, t2 = t1.to(torch.float16), t2.to(torch.float16) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * scale + else: + # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) 
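A scalar walk-through of the merge arithmetic in this helper may clarify the scaling (made-up numbers, plain floats standing in for tensors): `scale` is `alpha / dim` with `dim` taken from `hada_w1_b.shape[0]`, and the delta is added as W ← W + multiplier · ΔW.

```python
# Scalar stand-in for the LoHa merge arithmetic (illustrative numbers only):
#   scale = alpha / dim, then W <- W + multiplier * (raw_delta * scale)
dim = 4              # rank, read from hada_w1_b.shape[0]
alpha = 2.0          # stored ".alpha" value
multiplier = 0.8
scale = alpha / dim  # 0.5

raw_delta = 1.5      # ((w1a @ w1b) * (w2a @ w2b)) before scaling, as a scalar
model_weight = 0.25

merged = model_weight + multiplier * (raw_delta * scale)
assert abs(merged - 0.85) < 1e-9  # 0.25 + 0.8 * 0.75
```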
+ if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) + + model_weight = model_weight + multiplier * diff_weight + + if original_dtype.itemsize == 1: + model_weight = model_weight.to(original_dtype) + + # remove consumed keys + consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed.extend([t1_key, t2_key]) + for key in consumed: + lora_weight_keys.discard(key) + + return model_weight diff --git a/networks/lokr.py b/networks/lokr.py new file mode 100644 index 00000000..03b50ca0 --- /dev/null +++ b/networks/lokr.py @@ -0,0 +1,683 @@ +# LoKr (Low-rank Kronecker Product) network module +# Reference: https://arxiv.org/abs/2309.14859 +# +# Based on the LyCORIS project by KohakuBlueleaf +# https://github.com/KohakuBlueleaf/LyCORIS + +import ast +import math +import os +import logging +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs +from library.utils import setup_logging + +setup_logging() +logger = logging.getLogger(__name__) + + +def factorization(dimension: int, factor: int = -1) -> tuple: + """Return a tuple of two values whose product equals dimension, + optimized for balanced factors. + + In LoKr, the first value is for the weight scale (smaller), + and the second value is for the weight (larger). 
+ + Examples: + factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32) + factor=4: 128 -> (4, 32), 512 -> (4, 128) + """ + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + + +def make_kron(w1, w2, scale): + """Compute Kronecker product of w1 and w2, scaled by scale.""" + if w1.dim() != w2.dim(): + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + w2 = w2.contiguous() + rebuild = torch.kron(w1, w2) + if scale != 1: + rebuild = rebuild * scale + return rebuild + + +def rebuild_tucker(t, wa, wb): + """Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb). + + Compatible with LyCORIS convention. + """ + return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb) + + +class LoKrModule(torch.nn.Module): + """LoKr module for training. 
Replaces forward method of the original Linear/Conv2d.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + factor=-1, + use_tucker=False, + **kwargs, + ): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + is_conv2d = org_module.__class__.__name__ == "Conv2d" + if is_conv2d: + in_dim = org_module.in_channels + out_dim = org_module.out_channels + kernel_size = org_module.kernel_size + self.is_conv = True + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None + + self.in_dim = in_dim + self.out_dim = out_dim + + factor = int(factor) + self.use_w2 = False + + # Factorize dimensions + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + + # w1 is always a full matrix (the "scale" factor, small) + self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m)) + + # w2: depends on mode + if self.conv_mode in ("tucker", "flat"): + # Conv2d 3x3+ modes + k_size = kernel_size + + if lora_dim >= max(out_k, in_n) / 2: + # Full matrix mode (includes kernel dimensions) + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size)) + logger.warning( + f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " + f"and factor={factor}, using full matrix mode for Conv2d." 
+ ) + elif self.tucker: + # Tucker mode: separate kernel into t2 tensor + self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size)) + self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + # Non-Tucker: flatten kernel into w2_b + k_prod = 1 + for k in k_size: + k_prod *= k + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod)) + else: + # Linear or Conv2d 1x1 + if lora_dim < max(out_k, in_n) / 2: + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n)) + if lora_dim >= max(out_k, in_n) / 2: + logger.warning( + f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " + f"and factor={factor}, using full matrix mode." + ) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + # if both w1 and w2 are full matrices, use scale = 1 + if self.use_w2: + alpha = lora_dim + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # Initialization + torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) + if self.use_w2: + torch.nn.init.constant_(self.lokr_w2, 0) + else: + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) + torch.nn.init.constant_(self.lokr_w2_b, 0) + # Ensures ΔW = kron(w1, 0) = 0 at init + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def 
get_diff_weight(self): + """Return materialized weight delta. + + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D + """ + w1 = self.lokr_w1 + + if self.use_w2: + w2 = self.lokr_w2 + elif self.tucker: + w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) + else: + w2 = self.lokr_w2_a @ self.lokr_w2_b + + result = make_kron(w1, w2, self.scale) + + # For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D + if self.conv_mode == "flat" and result.dim() == 2: + result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size) + + return result + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + diff_weight = self.get_diff_weight() + + # rank dropout + if self.rank_dropout is not None and self.training: + drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) + diff_weight = diff_weight * drop + scale = 1.0 / (1.0 - self.rank_dropout) + else: + scale = 1.0 + + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale + + @property + def 
device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoKrInfModule(LoKrModule): + """LoKr module for inference. Supports merge_to and get_weight.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference; pass factor and use_tucker from kwargs + factor = kwargs.pop("factor", -1) + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker) + + self.org_module_ref = [org_module] + self.enabled = True + self.network: AdditionalNetwork = None + + def set_network(self, network): + self.network = network + + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get LoKr weights + w1 = sd["lokr_w1"].to(torch.float).to(device) + + if "lokr_w2" in sd: + w2 = sd["lokr_w2"].to(torch.float).to(device) + elif "lokr_t2" in sd: + # Tucker mode + t2 = sd["lokr_t2"].to(torch.float).to(device) + w2a = sd["lokr_w2_a"].to(torch.float).to(device) + w2b = sd["lokr_w2_b"].to(torch.float).to(device) + w2 = rebuild_tucker(t2, w2a, w2b) + else: + w2a = sd["lokr_w2_a"].to(torch.float).to(device) + w2b = sd["lokr_w2_b"].to(torch.float).to(device) + w2 = w2a @ w2b + + # compute ΔW via Kronecker product + diff_weight = make_kron(w1, w2, self.scale) + + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) + + weight = weight.to(device) + self.multiplier * diff_weight + + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def 
get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + w1 = self.lokr_w1.to(torch.float) + + if self.use_w2: + w2 = self.lokr_w2.to(torch.float) + elif self.tucker: + w2 = rebuild_tucker( + self.lokr_t2.to(torch.float), + self.lokr_w2_a.to(torch.float), + self.lokr_w2_b.to(torch.float), + ) + else: + w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float) + + weight = make_kron(w1, w2, self.scale) * multiplier + + # reshape to match original weight shape if needed + if self.is_conv: + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat" and weight.dim() == 2: + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) + # Tucker and full matrix modes: already 4D from kron + + return weight + + def default_forward(self, x): + diff_weight = self.get_diff_weight() + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier + else: + return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae, + text_encoder, + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + """Create a LoKr network. 
Called by train_network.py via network_module.create_network().""" + if network_dim is None: + network_dim = 4 + if network_alpha is None: + network_alpha = 1.0 + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + # train LLM adapter + train_llm_adapter = kwargs.get("train_llm_adapter", "false") + if train_llm_adapter is not None: + train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False + + # exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns from arch config + exclude_patterns.extend(arch_config.default_excludes) + + # include patterns + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # conv dim/alpha for Conv2d 3x3 + conv_lora_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_lora_dim is not None: + conv_lora_dim = int(conv_lora_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + + # factor for LoKr + factor = 
int(kwargs.get("factor", -1)) + + # verbose + verbose = kwargs.get("verbose", "false") + if verbose is not None: + verbose = True if str(verbose).lower() == "true" else False + + # regex-specific learning rates / dimensions + network_reg_lrs = kwargs.get("network_reg_lrs", None) + reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None + + network_reg_dims = kwargs.get("network_reg_dims", None) + reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + module_class=LoKrModule, + module_kwargs={"factor": factor, "use_tucker": use_tucker}, + conv_lora_dim=conv_lora_dim, + conv_alpha=conv_alpha, + train_llm_adapter=train_llm_adapter, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + reg_dims=reg_dims, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + # LoRA+ support + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, 
weights_sd=None, for_inference=False, **kwargs):
+    """Create a LoKr network from saved weights. Called by train_network.py."""
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # detect dim/alpha from weights
+    modules_dim = {}
+    modules_alpha = {}
+    train_llm_adapter = False
+    use_tucker = False
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lokr_w2_a" in key:
+            # low-rank mode: dim detection depends on Tucker vs non-Tucker
+            if lora_name + ".lokr_t2" in weights_sd:
+                # Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
+                dim = value.shape[0]
+            else:
+                # Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
+                dim = value.shape[1]
+            modules_dim[lora_name] = dim
+        elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
+            # full matrix mode: set dim large enough to trigger full-matrix path
+            if lora_name not in modules_dim:
+                modules_dim[lora_name] = max(value.shape[0], value.shape[1])
+
+        if "lokr_t2" in key:
+            use_tucker = True
+
+        if "llm_adapter" in lora_name:
+            train_llm_adapter = True
+
+    # handle text_encoder as list
+    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+    # detect architecture
+    arch_config = detect_arch_config(unet, text_encoders)
+
+    # extract factor for LoKr
+    factor = int(kwargs.get("factor", -1))
+
+    module_class = LoKrInfModule if for_inference else LoKrModule
+    module_kwargs = {"factor": factor, "use_tucker": use_tucker}
+
+    network = AdditionalNetwork(
+        text_encoders,
+        unet,
+        arch_config=arch_config,
+        multiplier=multiplier,
+        modules_dim=modules_dim,
+        modules_alpha=modules_alpha,
+        module_class=module_class,
+        module_kwargs=module_kwargs,
+        
train_llm_adapter=train_llm_adapter, + ) + return network, weights_sd + + +def merge_weights_to_tensor( + model_weight: torch.Tensor, + lora_name: str, + lora_sd: Dict[str, torch.Tensor], + lora_weight_keys: set, + multiplier: float, + calc_device: torch.device, +) -> torch.Tensor: + """Merge LoKr weights directly into a model weight tensor. + + Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3. + No Module/Network creation needed. Consumed keys are removed from lora_weight_keys. + Returns model_weight unchanged if no matching LoKr keys found. + """ + w1_key = lora_name + ".lokr_w1" + w2_key = lora_name + ".lokr_w2" + w2a_key = lora_name + ".lokr_w2_a" + w2b_key = lora_name + ".lokr_w2_b" + t2_key = lora_name + ".lokr_t2" + alpha_key = lora_name + ".alpha" + + if w1_key not in lora_weight_keys: + return model_weight + + w1 = lora_sd[w1_key].to(calc_device) + + # determine mode: full matrix vs Tucker vs low-rank + has_tucker = t2_key in lora_weight_keys + + if w2a_key in lora_weight_keys: + w2a = lora_sd[w2a_key].to(calc_device) + w2b = lora_sd[w2b_key].to(calc_device) + + if has_tucker: + # Tucker: w2a = (rank, out_k), dim = rank + dim = w2a.shape[0] + else: + # Non-Tucker low-rank: w2a = (out_k, rank), dim = rank + dim = w2a.shape[1] + + consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed_keys.append(t2_key) + elif w2_key in lora_weight_keys: + # full matrix mode + w2a = None + w2b = None + dim = None + consumed_keys = [w1_key, w2_key, alpha_key] + else: + return model_weight + + alpha = lora_sd.get(alpha_key, None) + if alpha is not None and isinstance(alpha, torch.Tensor): + alpha = alpha.item() + + # compute scale + if w2a is not None: + if alpha is None: + alpha = dim + scale = alpha / dim + else: + # full matrix mode: scale = 1.0 + scale = 1.0 + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(torch.float16) + w1 = w1.to(torch.float16) + if w2a is 
not None: + w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16) + + # compute w2 + if w2a is not None: + if has_tucker: + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t2 = t2.to(torch.float16) + w2 = rebuild_tucker(t2, w2a, w2b) + else: + w2 = w2a @ w2b + else: + w2 = lora_sd[w2_key].to(calc_device) + if original_dtype.itemsize == 1: + w2 = w2.to(torch.float16) + + # ΔW = kron(w1, w2) * scale + diff_weight = make_kron(w1, w2, scale) + + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) + if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) + + model_weight = model_weight + multiplier * diff_weight + + if original_dtype.itemsize == 1: + model_weight = model_weight.to(original_dtype) + + # remove consumed keys + for key in consumed_keys: + lora_weight_keys.discard(key) + + return model_weight diff --git a/networks/lora_anima.py b/networks/lora_anima.py index 9413e8c8..4cff2819 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -1,11 +1,11 @@ # LoRA network module for Anima import ast +import math import os import re from typing import Dict, List, Optional, Tuple, Type, Union import torch from library.utils import setup_logging -from networks.lora_flux import LoRAModule, LoRAInfModule import logging @@ -13,6 +13,213 @@ setup_logging() logger = logging.getLogger(__name__) +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). 
+ """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), 
device=lx.device) > self.rank_dropout
+            if isinstance(self.lora_down, torch.nn.Conv2d):
+                # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
+                mask = mask.unsqueeze(-1).unsqueeze(-1)
+            else:
+                # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
+                for _ in range(len(lx.size()) - 2):
+                    mask = mask.unsqueeze(1)
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            # the scale could also be computed from the mask, but rank_dropout is used here to get an augmentation-like effect
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded + lx * self.multiplier * scale
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+
+class LoRAInfModule(LoRAModule):
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        **kwargs,
+    ):
+        # no dropout for inference
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+        self.org_module_ref = [org_module]  # keep a reference for later access
+        self.enabled = True
+        self.network: LoRANetwork = None
+
+    def set_network(self, network):
+        self.network = network
+
+    # freeze and merge
+    def merge_to(self, sd, dtype, device):
+        # extract weight from org_module
+        org_sd = self.org_module.state_dict()
+        weight = org_sd["weight"]
+        org_dtype = weight.dtype
+        org_device = weight.device
+        weight = weight.to(torch.float)  # calc in float
+
+        if dtype is None:
+            dtype = org_dtype
+        if device is None:
+            device = org_device
+
+        # get up/down weight
+        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+
+        # merge weight
+        if len(weight.size()) == 2:
+            # linear
+            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                weight
+                + self.multiplier
+                * 
(up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+            weight = weight + self.multiplier * conved * self.scale
+
+        # set weight to org_module
+        org_sd["weight"] = weight.to(dtype)
+        self.org_module.load_state_dict(org_sd)
+
+    # return this module's weight so that the merge can be restored later
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # get up/down weight from module
+        up_weight = self.lora_up.weight.to(torch.float)
+        down_weight = self.lora_down.weight.to(torch.float)
+
+        # pre-calculated weight
+        if len(down_weight.size()) == 2:
+            # linear
+            weight = self.multiplier * (up_weight @ down_weight) * self.scale
+        elif down_weight.size()[2:4] == (1, 1):
+            # conv2d 1x1
+            weight = (
+                self.multiplier
+                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                * self.scale
+            )
+        else:
+            # conv2d 3x3
+            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+            weight = self.multiplier * conved * self.scale
+
+        return weight
+
+    def default_forward(self, x):
+        # logger.info(f"default_forward {self.lora_name} {x.size()}")
+        lx = self.lora_down(x)
+        lx = self.lora_up(lx)
+        return self.org_forward(x) + lx * self.multiplier * self.scale
+
+    def forward(self, x):
+        if not self.enabled:
+            return self.org_forward(x)
+        return self.default_forward(x)
+
+
 def create_network(
     multiplier: float,
     network_dim: Optional[int],
diff --git a/networks/network_base.py b/networks/network_base.py
new file mode 100644
index 00000000..d9697562
--- /dev/null
+++ b/networks/network_base.py
@@ -0,0 +1,545 @@
+# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc). 
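The LoKr modules in this patch rebuild the weight delta as ΔW = kron(w1, w2) · scale, where w1 is a small full matrix and w2 is optionally a low-rank product `w2_a @ w2_b`; the shapes come from splitting `out_dim` and `in_dim` with a `factorization` helper. A minimal sketch of the shape bookkeeping follows; the `factorization` below is an illustrative re-implementation for this example, not the one defined in `networks/lokr.py`:

```python
import torch

def factorization(dim: int, factor: int = -1):
    # Illustrative helper (assumption: approximates the one in networks/lokr.py):
    # split dim into (m, n) with m * n == dim, where m is the largest divisor of
    # dim not exceeding sqrt(dim) (or `factor`, if given).
    limit = factor if factor > 0 else int(dim**0.5)
    m = 1
    for i in range(1, limit + 1):
        if dim % i == 0:
            m = i
    return m, dim // m

# ΔW for a Linear(48 -> 64) layer as kron(w1, w2) with a rank-4 w2:
out_dim, in_dim, rank = 64, 48, 4
out_l, out_k = factorization(out_dim)  # (8, 8)
in_m, in_n = factorization(in_dim)     # (6, 8)
w1 = torch.randn(out_l, in_m)                            # small full matrix
w2 = torch.randn(out_k, rank) @ torch.randn(rank, in_n)  # low-rank product w2_a @ w2_b
delta_w = torch.kron(w1, w2)  # shape (out_l * out_k, in_m * in_n) == (out_dim, in_dim)
assert delta_w.shape == (out_dim, in_dim)
```

In the full-matrix path (`use_w2`), w2 is stored directly as a single (out_k, in_n) matrix instead of the `w2_a @ w2_b` product, and `make_kron` in the patch additionally folds in the alpha/rank scale.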
+# Provides architecture detection and a generic AdditionalNetwork class. + +import os +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ArchConfig: + unet_target_modules: List[str] + te_target_modules: List[str] + unet_prefix: str + te_prefixes: List[str] + default_excludes: List[str] = field(default_factory=list) + adapter_target_modules: List[str] = field(default_factory=list) + unet_conv_target_modules: List[str] = field(default_factory=list) + + +def detect_arch_config(unet, text_encoders) -> ArchConfig: + """Detect architecture from model structure and return ArchConfig.""" + from library.sdxl_original_unet import SdxlUNet2DConditionModel + + # Check SDXL first + if unet is not None and ( + issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel) + ): + return ArchConfig( + unet_target_modules=["Transformer2DModel"], + te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"], + unet_prefix="lora_unet", + te_prefixes=["lora_te1", "lora_te2"], + default_excludes=[], + unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"], + ) + + # Check Anima: look for Block class in named_modules + module_class_names = set() + if unet is not None: + for module in unet.modules(): + module_class_names.add(type(module).__name__) + + if "Block" in module_class_names: + return ArchConfig( + unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"], + te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"], + unet_prefix="lora_unet", + te_prefixes=["lora_te"], + default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"], + 
adapter_target_modules=["LLMAdapterTransformerBlock"], + ) + + raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}") + + +def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]: + """Parse a string of key-value pairs separated by commas.""" + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + +class AdditionalNetwork(torch.nn.Module): + """Generic Additional network that supports LoHa, LoKr, and similar module types. + + Constructed with a module_class parameter to inject the specific module type. + Based on the lora_anima.py LoRANetwork, generalized for multiple architectures. 
+ """ + + def __init__( + self, + text_encoders: list, + unet, + arch_config: ArchConfig, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + module_class: Type[torch.nn.Module] = None, + module_kwargs: Optional[Dict] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, + reg_dims: Optional[Dict[str, int]] = None, + reg_lrs: Optional[Dict[str, float]] = None, + train_llm_adapter: bool = False, + verbose: bool = False, + ) -> None: + super().__init__() + assert module_class is not None, "module_class must be specified" + + self.multiplier = multiplier + self.lora_dim = lora_dim + self.alpha = alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.train_llm_adapter = train_llm_adapter + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs + self.arch_config = arch_config + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if module_kwargs is None: + module_kwargs = {} + + if modules_dim is not None: + logger.info(f"create {module_class.__name__} network from weights") + else: + logger.info(f"create {module_class.__name__} network. 
base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + + # compile regular expressions + def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]: + re_patterns = [] + if patterns is not None: + for pattern in patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid pattern '{pattern}': {e}") + continue + re_patterns.append(re_pattern) + return re_patterns + + exclude_re_patterns = str_to_re_patterns(exclude_patterns) + include_re_patterns = str_to_re_patterns(include_patterns) + + # create module instances + def create_modules( + prefix: str, + root_module: torch.nn.Module, + target_replace_modules: List[str], + default_dim: Optional[int] = None, + ) -> Tuple[List[torch.nn.Module], List[str]]: + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: + module = root_module + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + original_name = (name + "." 
if name else "") + child_name + lora_name = f"{prefix}.{original_name}".replace(".", "_") + + # exclude/include filter + excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns) + included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns) + if excluded and not included: + if verbose: + logger.info(f"exclude: {original_name}") + continue + + dim = None + alpha_val = None + + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha_val = modules_alpha[lora_name] + else: + if self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.fullmatch(reg, original_name): + dim = d + alpha_val = self.alpha + logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}") + break + # fallback to default dim + if dim is None: + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha_val = self.alpha + elif is_conv2d and self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha_val = self.conv_alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1: + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha_val, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + **module_kwargs, + ) + lora.original_name = original_name + loras.append(lora) + + if target_replace_modules is None: + break + return loras, skipped + + # Create modules for text encoders + self.text_encoder_loras: List[torch.nn.Module] = [] + skipped_te = [] + if text_encoders is not None: + for i, text_encoder in enumerate(text_encoders): + if text_encoder is None: + continue + + # Determine prefix for this text encoder + if i < len(arch_config.te_prefixes): + te_prefix = arch_config.te_prefixes[i] + else: + te_prefix = arch_config.te_prefixes[0] + + logger.info(f"create {module_class.__name__} for Text Encoder {i+1} 
(prefix={te_prefix}):")
+            te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
+            logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
+            self.text_encoder_loras.extend(te_loras)
+            skipped_te += te_skipped
+
+        # Create modules for UNet/DiT
+        target_modules = list(arch_config.unet_target_modules)
+        if modules_dim is not None or conv_lora_dim is not None:
+            target_modules.extend(arch_config.unet_conv_target_modules)
+        if train_llm_adapter and arch_config.adapter_target_modules:
+            target_modules.extend(arch_config.adapter_target_modules)
+
+        self.unet_loras: List[torch.nn.Module]
+        self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
+        logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")
+
+        if verbose:
+            for lora in self.unet_loras:
+                logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
+
+        skipped = skipped_te + skipped_un
+        if verbose and len(skipped) > 0:
+            logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
+            for name in skipped:
+                logger.info(f"\t{name}")
+
+        # assertion: no duplicate names
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def set_enabled(self, is_enabled):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.enabled = is_enabled
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    def is_mergeable(self):
+        return True
+
+    def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
+        apply_text_encoder = apply_unet = False
+        te_prefixes = self.arch_config.te_prefixes
+        unet_prefix = self.arch_config.unet_prefix
+
+        for key in weights_sd.keys():
+            if any(key.startswith(p) for p in te_prefixes):
+                apply_text_encoder = True
+            elif key.startswith(unet_prefix):
+                apply_unet = True
+
+        if apply_text_encoder:
+            logger.info("enable modules for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            logger.info("enable modules for UNet/DiT")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(lora.lora_name):
+                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+            lora.merge_to(sd_for_lora, dtype, device)
+
+        logger.info("weights are merged")
+
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
+        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
+        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
+    def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
+        if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
+            text_encoder_lr = [default_lr]
+        elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
+            text_encoder_lr = [float(text_encoder_lr)]
+        elif len(text_encoder_lr) == 1:
+            pass  # already a list with one element
+
+        self.requires_grad_(True)
+
+        all_params = []
+        lr_descriptions = []
+
+        def assemble_params(loras, lr, loraplus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
+            reg_groups = {}
+            reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
+
+            for lora in loras:
+                matched_reg_lr = None
+                for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
+                    if re.fullmatch(regex_str, lora.original_name):
+                        matched_reg_lr = (i, reg_lr)
+                        logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
+                        break
+
+                for name, param in lora.named_parameters():
+                    if matched_reg_lr is not None:
+                        reg_idx, reg_lr = matched_reg_lr
+                        group_key = f"reg_lr_{reg_idx}"
+                        if group_key not in reg_groups:
+                            reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
+                        # LoRA+ detection: check for "up" weight parameters
+                        if loraplus_ratio is not None and self._is_plus_param(name):
+                            reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
+                        else:
+                            reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
+                        continue
+
+                    if loraplus_ratio is not None and self._is_plus_param(name):
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            descriptions = []
+            for group_key, group in reg_groups.items():
+                reg_lr = group["lr"]
+                for key in ("lora", "plus"):
+                    param_data = {"params": group[key].values()}
+                    if len(param_data["params"]) == 0:
+                        continue
+                    if key == "plus":
+                        param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
+                    else:
+                        param_data["lr"] = reg_lr
+                    if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                        logger.info("NO LR skipping!")
+                        continue
+                    params.append(param_data)
+                    desc = f"reg_lr_{group_key.split('_')[-1]}"
+                    descriptions.append(desc + (" plus" if key == "plus" else ""))
+
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if len(param_data["params"]) == 0:
+                    continue
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * loraplus_ratio
+                    else:
+                        param_data["lr"] = lr
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    logger.info("NO LR skipping!")
+                    continue
+                params.append(param_data)
+                descriptions.append("plus" if key == "plus" else "")
+            return params, descriptions
+
+        if self.text_encoder_loras:
+            loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
+            # Group TE loras by prefix
+            for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
+                te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
+                if len(te_loras) > 0:
+                    te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
+                    logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
+                    params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
+                    all_params.extend(params)
+                    lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
+
+        if self.unet_loras:
+            params, descriptions = assemble_params(
+                self.unet_loras,
+                unet_lr if unet_lr is not None else default_lr,
+                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
+            )
+            all_params.extend(params)
+            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
+
+        return all_params, lr_descriptions
+
+    def _is_plus_param(self, name: str) -> bool:
+        """Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
+
+        For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
+        Override in subclass if needed. Default: check for common 'up' patterns.
+        """
+        return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
+
+    def enable_gradient_checkpointing(self):
+        pass  # not supported
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+            from library import train_util
+
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    def backup_weights(self):
+        loras = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        loras = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        loras = self.text_encoder_loras + self.unet_loras
+        for lora in loras:
+            org_module = lora.org_module_ref[0]
+            sd = org_module.state_dict()
+
+            org_weight = sd["weight"]
+            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            sd["weight"] = org_weight + lora_weight
+            assert sd["weight"].shape == org_weight.shape
+            org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            lora.enabled = False
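
The LoRA+ split that `assemble_params` and `_is_plus_param` implement in the patch above can be sketched in isolation. This is a minimal illustration in plain Python (no torch, no optimizer); `is_plus_param` and `split_params` are hypothetical names for this sketch, not APIs from the repository. It assumes the same rule as the patch: parameters whose names match the "up"-side patterns (`lora_up`, `hada_w2_a`, `lokr_w1`) go into a group whose LR is multiplied by `loraplus_ratio`, everything else keeps the base LR.

```python
# Sketch of the LoRA+ parameter grouping (illustrative, not the repo's API).
# "plus" params (the up-side factors) get lr * loraplus_ratio; the rest get lr.

def is_plus_param(name: str) -> bool:
    # mirrors the name patterns checked by _is_plus_param in the patch
    return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name

def split_params(named_params, lr, loraplus_ratio):
    """Build optimizer-style param groups: a base-LR group and a boosted 'plus' group."""
    groups = {"lora": [], "plus": []}
    for name, param in named_params:
        key = "plus" if loraplus_ratio is not None and is_plus_param(name) else "lora"
        groups[key].append(param)
    out = []
    if groups["lora"]:
        out.append({"params": groups["lora"], "lr": lr})
    if groups["plus"]:
        out.append({"params": groups["plus"], "lr": lr * loraplus_ratio})
    return out

# Example: one down/up pair with a 16x LoRA+ boost on the up weight.
params = [("lora_down.weight", "A"), ("lora_up.weight", "B")]
groups = split_params(params, lr=1e-4, loraplus_ratio=16)
# -> two groups: ["A"] at lr 1e-4, ["B"] at lr 1.6e-3
```

The groups produced this way can be passed directly to a `torch.optim` optimizer, which accepts a list of dicts with per-group `lr` overrides; the patch's version additionally layers per-module regex LR overrides (`reg_lrs`) on top of this split.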