diff --git a/.ai/context/01-overview.md b/.ai/context/01-overview.md
index 41133e98..c37aba19 100644
--- a/.ai/context/01-overview.md
+++ b/.ai/context/01-overview.md
@@ -21,6 +21,9 @@ Each supported model family has a consistent structure:
- **SDXL**: `sdxl_train*.py`, `library/sdxl_*`
- **SD3**: `sd3_train*.py`, `library/sd3_*`
- **FLUX.1**: `flux_train*.py`, `library/flux_*`
+- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*`
+- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*`
+- **Anima-Preview**: `anima_train*.py`, `library/anima_*`
### Key Components
diff --git a/.gitignore b/.gitignore
index cfdc0268..f5772a7f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
+references
\ No newline at end of file
diff --git a/docs/loha_lokr.md b/docs/loha_lokr.md
new file mode 100644
index 00000000..6f16ba66
--- /dev/null
+++ b/docs/loha_lokr.md
@@ -0,0 +1,359 @@
+> 📝 Click on the language section to expand / 言語をクリックして展開
+
+# LoHa / LoKr (LyCORIS)
+
+## Overview / 概要
+
+In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project.
+
+- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
+- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
+
+The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md).
+
+Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available.
+
+This feature is experimental.
+
+
+日本語
+
+sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**(Low-rank Hadamard Product)と **LoKr**(Low-rank Kronecker Product)をサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。
+
+- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098)
+- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859)
+
+アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。
+
+LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflat(カーネル平坦化)モードが利用可能です。
+
+この機能は実験的なものです。
+
+
+
+## Acknowledgments / 謝辞
+
+The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible.
+
+
+日本語
+
+sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。
+
+
+
+## Supported architectures / 対応アーキテクチャ
+
+LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported:
+
+- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`.
+- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules.
+
+
+日本語
+
+LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています:
+
+- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D`、`Downsample2D`、`Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。
+- **Anima**: DiTの`Block`、`PatchEmbed`、`TimestepEmbedding`、`FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。
+
+
+
+## Training / 学習
+
+To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA.
+
+
+日本語
+
+LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプション(データセット設定、オプティマイザなど)はLoRAと同じです。
+
+
+
+### LoHa (SDXL)
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
+ --pretrained_model_name_or_path path/to/sdxl.safetensors \
+ --dataset_config path/to/toml \
+ --mixed_precision bf16 --fp8_base \
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
+ --network_module networks.loha --network_dim 32 --network_alpha 16 \
+ --max_train_epochs 16 --save_every_n_epochs 1 \
+ --output_dir path/to/output --output_name my-loha
+```
+
+### LoKr (SDXL)
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \
+ --pretrained_model_name_or_path path/to/sdxl.safetensors \
+ --dataset_config path/to/toml \
+ --mixed_precision bf16 --fp8_base \
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
+ --network_module networks.lokr --network_dim 32 --network_alpha 16 \
+ --max_train_epochs 16 --save_every_n_epochs 1 \
+ --output_dir path/to/output --output_name my-lokr
+```
+
+For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options.
+
+
+日本語
+
+Animaの場合は、`sdxl_train_network.py` を `anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。
+
+
+
+### Common training options / 共通の学習オプション
+
+The following `--network_args` options are available for both LoHa and LoKr, same as LoRA:
+
+| Option | Description |
+|---|---|
+| `verbose=True` | Display detailed information about the network modules |
+| `rank_dropout=0.1` | Apply dropout to the rank dimension during training |
+| `module_dropout=0.1` | Randomly skip entire modules during training |
+| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) |
+| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` |
+| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns |
+| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns |
+
+
+日本語
+
+以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます:
+
+| オプション | 説明 |
+|---|---|
+| `verbose=True` | ネットワークモジュールの詳細情報を表示 |
+| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 |
+| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ |
+| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) |
+| `include_patterns=[r'...']` | 正規表現パターンに一致するモジュールのみを対象とする |
+| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 |
+| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 |
+
+
+
+### Conv2d support / Conv2dサポート
+
+By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options:
+
+```bash
+--network_args "conv_dim=16" "conv_alpha=8"
+```
+
+For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation:
+
+```bash
+--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
+```
+
+- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode).
+- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient.
+
+
+日本語
+
+デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層(SDXLのResNetブロックなど)も学習するには、`conv_dim`と`conv_alpha`オプションを使用します:
+
+```bash
+--network_args "conv_dim=16" "conv_alpha=8"
+```
+
+Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です:
+
+```bash
+--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True"
+```
+
+- `use_tucker`なし: カーネル次元が入力次元に平坦化されます(flatモード)。
+- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。
+
+
+
+### LoKr-specific option: `factor` / LoKr固有のオプション: `factor`
+
+LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split:
+
+- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32).
+- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128).
+
+```bash
+--network_args "factor=4"
+```
+
+When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case.
+
+
+日本語
+
+LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します:
+
+- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。
+- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。
+
+```bash
+--network_args "factor=4"
+```
+
+`network_dim`(ランク)が因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。
+
+
+
+### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter`
+
+For Anima, you can additionally train the LLM adapter modules by specifying:
+
+```bash
+--network_args "train_llm_adapter=True"
+```
+
+This includes `LLMAdapterTransformerBlock` modules as training targets.
+
+
+日本語
+
+Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます:
+
+```bash
+--network_args "train_llm_adapter=True"
+```
+
+これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。
+
+
+
+### LoRA+ / LoRA+
+
+LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group.
+
+```bash
+--network_args "loraplus_lr_ratio=4"
+```
+
+This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them.
+
+
+日本語
+
+LoRA+(`--network_args` の `loraplus_lr_ratio` 等)はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列(`hada_w2_a`)が「plus」(より高い学習率)パラメータグループとして扱われます。LoKrではスケール係数(`lokr_w1`)が「plus」パラメータグループとして扱われます。
+
+```bash
+--network_args "loraplus_lr_ratio=4"
+```
+
+この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。
+
+
+
+## How LoHa and LoKr work / LoHaとLoKrの仕組み
+
+### LoHa
+
+LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices:
+
+```
+ΔW = (W1a × W1b) ⊙ (W2a × W2b)
+```
+
+where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product.
+
+For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`.
+
+### LoKr
+
+LoKr represents the weight update using a Kronecker product:
+
+```
+ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode)
+```
+
+The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank.
+
+
+日本語
+
+### LoHa
+
+LoHaは重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します:
+
+```
+ΔW = (W1a × W1b) ⊙ (W2a × W2b)
+```
+
+ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。
+
+Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。
+
+### LoKr
+
+LoKrはKronecker積を使って重みの更新を表現します:
+
+```
+ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b)
+```
+
+元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます)。W1は常にフル行列(小さい)で、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。
+
+
+
+## Inference / 推論
+
+Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA.
+
+
+日本語
+
+学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。
+
+
+
+### SDXL
+
+For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA:
+
+```bash
+python gen_img.py --ckpt path/to/sdxl.safetensors \
+ --network_module networks.loha --network_weights path/to/loha.safetensors \
+ --prompt "your prompt" ...
+```
+
+Replace `networks.loha` with `networks.lokr` for LoKr weights.
+
+
+日本語
+
+SDXLでは、LoRAと同様に `gen_img.py` で `--network_module` と `--network_weights` を指定します:
+
+```bash
+python gen_img.py --ckpt path/to/sdxl.safetensors \
+ --network_module networks.loha --network_weights path/to/loha.safetensors \
+ --prompt "your prompt" ...
+```
+
+LoKrの重みを使用する場合は `networks.loha` を `networks.lokr` に置き換えてください。
+
+
+
+### Anima
+
+For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged:
+
+```bash
+python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
+ --lora_weight path/to/loha_or_lokr.safetensors ...
+```
+
+
+日本語
+
+Animaでは、`anima_minimal_inference.py` に `--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます:
+
+```bash
+python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \
+ --lora_weight path/to/loha_or_lokr.safetensors ...
+```
+
+
diff --git a/library/lora_utils.py b/library/lora_utils.py
index 90e3c389..dadad898 100644
--- a/library/lora_utils.py
+++ b/library/lora_utils.py
@@ -1,267 +1,287 @@
-import os
-import re
-from typing import Dict, List, Optional, Union
-import torch
-from tqdm import tqdm
-from library.device_utils import synchronize_device
-from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
-from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
-from library.utils import setup_logging
-
-setup_logging()
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def filter_lora_state_dict(
- weights_sd: Dict[str, torch.Tensor],
- include_pattern: Optional[str] = None,
- exclude_pattern: Optional[str] = None,
-) -> Dict[str, torch.Tensor]:
- # apply include/exclude patterns
- original_key_count = len(weights_sd.keys())
- if include_pattern is not None:
- regex_include = re.compile(include_pattern)
- weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
- logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
-
- if exclude_pattern is not None:
- original_key_count_ex = len(weights_sd.keys())
- regex_exclude = re.compile(exclude_pattern)
- weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
- logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
-
- if len(weights_sd) != original_key_count:
- remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
- remaining_keys.sort()
- logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
- if len(weights_sd) == 0:
- logger.warning("No keys left after filtering.")
-
- return weights_sd
-
-
-def load_safetensors_with_lora_and_fp8(
- model_files: Union[str, List[str]],
- lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
- lora_multipliers: Optional[List[float]],
- fp8_optimization: bool,
- calc_device: torch.device,
- move_to_device: bool = False,
- dit_weight_dtype: Optional[torch.dtype] = None,
- target_keys: Optional[List[str]] = None,
- exclude_keys: Optional[List[str]] = None,
- disable_numpy_memmap: bool = False,
- weight_transform_hooks: Optional[WeightTransformHooks] = None,
-) -> dict[str, torch.Tensor]:
- """
- Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
-
- Args:
- model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
- lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
- lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
- fp8_optimization (bool): Whether to apply FP8 optimization.
- calc_device (torch.device): Device to calculate on.
- move_to_device (bool): Whether to move tensors to the calculation device after loading.
- target_keys (Optional[List[str]]): Keys to target for optimization.
- exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
- disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
- weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
- """
-
- # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
- if isinstance(model_files, str):
- model_files = [model_files]
-
- extended_model_files = []
- for model_file in model_files:
- split_filenames = get_split_weight_filenames(model_file)
- if split_filenames is not None:
- extended_model_files.extend(split_filenames)
- else:
- extended_model_files.append(model_file)
- model_files = extended_model_files
- logger.info(f"Loading model files: {model_files}")
-
- # load LoRA weights
- weight_hook = None
- if lora_weights_list is None or len(lora_weights_list) == 0:
- lora_weights_list = []
- lora_multipliers = []
- list_of_lora_weight_keys = []
- else:
- list_of_lora_weight_keys = []
- for lora_sd in lora_weights_list:
- lora_weight_keys = set(lora_sd.keys())
- list_of_lora_weight_keys.append(lora_weight_keys)
-
- if lora_multipliers is None:
- lora_multipliers = [1.0] * len(lora_weights_list)
- while len(lora_multipliers) < len(lora_weights_list):
- lora_multipliers.append(1.0)
- if len(lora_multipliers) > len(lora_weights_list):
- lora_multipliers = lora_multipliers[: len(lora_weights_list)]
-
- # Merge LoRA weights into the state dict
- logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
-
- # make hook for LoRA merging
- def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
- nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
-
- if not model_weight_key.endswith(".weight"):
- return model_weight
-
- original_device = model_weight.device
- if original_device != calc_device:
- model_weight = model_weight.to(calc_device) # to make calculation faster
-
- for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
- # check if this weight has LoRA weights
- lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
- found = False
- for prefix in ["lora_unet_", ""]:
- lora_name = prefix + lora_name_without_prefix.replace(".", "_")
- down_key = lora_name + ".lora_down.weight"
- up_key = lora_name + ".lora_up.weight"
- alpha_key = lora_name + ".alpha"
- if down_key in lora_weight_keys and up_key in lora_weight_keys:
- found = True
- break
- if not found:
- continue # no LoRA weights for this model weight
-
- # get LoRA weights
- down_weight = lora_sd[down_key]
- up_weight = lora_sd[up_key]
-
- dim = down_weight.size()[0]
- alpha = lora_sd.get(alpha_key, dim)
- scale = alpha / dim
-
- down_weight = down_weight.to(calc_device)
- up_weight = up_weight.to(calc_device)
-
- original_dtype = model_weight.dtype
- if original_dtype.itemsize == 1: # fp8
- # temporarily convert to float16 for calculation
- model_weight = model_weight.to(torch.float16)
- down_weight = down_weight.to(torch.float16)
- up_weight = up_weight.to(torch.float16)
-
- # W <- W + U * D
- if len(model_weight.size()) == 2:
- # linear
- if len(up_weight.size()) == 4: # use linear projection mismatch
- up_weight = up_weight.squeeze(3).squeeze(2)
- down_weight = down_weight.squeeze(3).squeeze(2)
- model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
- elif down_weight.size()[2:4] == (1, 1):
- # conv2d 1x1
- model_weight = (
- model_weight
- + multiplier
- * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
- * scale
- )
- else:
- # conv2d 3x3
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
- # logger.info(conved.size(), weight.size(), module.stride, module.padding)
- model_weight = model_weight + multiplier * conved * scale
-
- if original_dtype.itemsize == 1: # fp8
- model_weight = model_weight.to(original_dtype) # convert back to original dtype
-
- # remove LoRA keys from set
- lora_weight_keys.remove(down_key)
- lora_weight_keys.remove(up_key)
- if alpha_key in lora_weight_keys:
- lora_weight_keys.remove(alpha_key)
-
- if not keep_on_calc_device and original_device != calc_device:
- model_weight = model_weight.to(original_device) # move back to original device
- return model_weight
-
- weight_hook = weight_hook_func
-
- state_dict = load_safetensors_with_fp8_optimization_and_hook(
- model_files,
- fp8_optimization,
- calc_device,
- move_to_device,
- dit_weight_dtype,
- target_keys,
- exclude_keys,
- weight_hook=weight_hook,
- disable_numpy_memmap=disable_numpy_memmap,
- weight_transform_hooks=weight_transform_hooks,
- )
-
- for lora_weight_keys in list_of_lora_weight_keys:
- # check if all LoRA keys are used
- if len(lora_weight_keys) > 0:
- # if there are still LoRA keys left, it means they are not used in the model
- # this is a warning, not an error
- logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
-
- return state_dict
-
-
-def load_safetensors_with_fp8_optimization_and_hook(
- model_files: list[str],
- fp8_optimization: bool,
- calc_device: torch.device,
- move_to_device: bool = False,
- dit_weight_dtype: Optional[torch.dtype] = None,
- target_keys: Optional[List[str]] = None,
- exclude_keys: Optional[List[str]] = None,
- weight_hook: callable = None,
- disable_numpy_memmap: bool = False,
- weight_transform_hooks: Optional[WeightTransformHooks] = None,
-) -> dict[str, torch.Tensor]:
- """
- Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
- """
- if fp8_optimization:
- logger.info(
- f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
- )
- # dit_weight_dtype is not used because we use fp8 optimization
- state_dict = load_safetensors_with_fp8_optimization(
- model_files,
- calc_device,
- target_keys,
- exclude_keys,
- move_to_device=move_to_device,
- weight_hook=weight_hook,
- disable_numpy_memmap=disable_numpy_memmap,
- weight_transform_hooks=weight_transform_hooks,
- )
- else:
- logger.info(
- f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
- )
- state_dict = {}
- for model_file in model_files:
- with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
- f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
- for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
- if weight_hook is None and move_to_device:
- value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
- else:
- value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
- if weight_hook is not None:
- value = weight_hook(key, value, keep_on_calc_device=move_to_device)
- if move_to_device:
- value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
- elif dit_weight_dtype is not None:
- value = value.to(dit_weight_dtype)
-
- state_dict[key] = value
- if move_to_device:
- synchronize_device(calc_device)
-
- return state_dict
+import os
+import re
+from typing import Dict, List, Optional, Union
+import torch
+from tqdm import tqdm
+from library.device_utils import synchronize_device
+from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
+from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
+from networks.loha import merge_weights_to_tensor as loha_merge
+from networks.lokr import merge_weights_to_tensor as lokr_merge
+
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def filter_lora_state_dict(
+ weights_sd: Dict[str, torch.Tensor],
+ include_pattern: Optional[str] = None,
+ exclude_pattern: Optional[str] = None,
+) -> Dict[str, torch.Tensor]:
+ # apply include/exclude patterns
+ original_key_count = len(weights_sd.keys())
+ if include_pattern is not None:
+ regex_include = re.compile(include_pattern)
+ weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
+ logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
+
+ if exclude_pattern is not None:
+ original_key_count_ex = len(weights_sd.keys())
+ regex_exclude = re.compile(exclude_pattern)
+ weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
+ logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
+
+ if len(weights_sd) != original_key_count:
+ remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
+ remaining_keys.sort()
+ logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
+ if len(weights_sd) == 0:
+ logger.warning("No keys left after filtering.")
+
+ return weights_sd
+
+
+def load_safetensors_with_lora_and_fp8(
+ model_files: Union[str, List[str]],
+ lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
+ lora_multipliers: Optional[List[float]],
+ fp8_optimization: bool,
+ calc_device: torch.device,
+ move_to_device: bool = False,
+ dit_weight_dtype: Optional[torch.dtype] = None,
+ target_keys: Optional[List[str]] = None,
+ exclude_keys: Optional[List[str]] = None,
+ disable_numpy_memmap: bool = False,
+ weight_transform_hooks: Optional[WeightTransformHooks] = None,
+) -> dict[str, torch.Tensor]:
+ """
+ Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
+
+ Args:
+ model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
+ lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
+ lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
+ fp8_optimization (bool): Whether to apply FP8 optimization.
+ calc_device (torch.device): Device to calculate on.
+ move_to_device (bool): Whether to move tensors to the calculation device after loading.
+ dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization.
+ target_keys (Optional[List[str]]): Keys to target for optimization.
+ exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
+ disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
+ weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
+ """
+
+ # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
+ if isinstance(model_files, str):
+ model_files = [model_files]
+
+ extended_model_files = []
+ for model_file in model_files:
+ split_filenames = get_split_weight_filenames(model_file)
+ if split_filenames is not None:
+ extended_model_files.extend(split_filenames)
+ else:
+ extended_model_files.append(model_file)
+ model_files = extended_model_files
+ logger.info(f"Loading model files: {model_files}")
+
+ # load LoRA weights
+ weight_hook = None
+ if lora_weights_list is None or len(lora_weights_list) == 0:
+ lora_weights_list = []
+ lora_multipliers = []
+ list_of_lora_weight_keys = []
+ else:
+ list_of_lora_weight_keys = []
+ for lora_sd in lora_weights_list:
+ lora_weight_keys = set(lora_sd.keys())
+ list_of_lora_weight_keys.append(lora_weight_keys)
+
+ if lora_multipliers is None:
+ lora_multipliers = [1.0] * len(lora_weights_list)
+ while len(lora_multipliers) < len(lora_weights_list):
+ lora_multipliers.append(1.0)
+ if len(lora_multipliers) > len(lora_weights_list):
+ lora_multipliers = lora_multipliers[: len(lora_weights_list)]
+
+ # Merge LoRA weights into the state dict
+ logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
+
+ # make hook for LoRA merging
+ def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
+ nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
+
+ if not model_weight_key.endswith(".weight"):
+ return model_weight
+
+ original_device = model_weight.device
+ if original_device != calc_device:
+ model_weight = model_weight.to(calc_device) # to make calculation faster
+
+ for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
+ # check if this weight has LoRA weights
+ lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
+ found = False
+ for prefix in ["lora_unet_", ""]:
+ lora_name = prefix + lora_name_without_prefix.replace(".", "_")
+ down_key = lora_name + ".lora_down.weight"
+ up_key = lora_name + ".lora_up.weight"
+ alpha_key = lora_name + ".alpha"
+ if down_key in lora_weight_keys and up_key in lora_weight_keys:
+ found = True
+ break
+
+ if found:
+ # Standard LoRA merge
+ # get LoRA weights
+ down_weight = lora_sd[down_key]
+ up_weight = lora_sd[up_key]
+
+ dim = down_weight.size()[0]
+ alpha = lora_sd.get(alpha_key, dim)
+ scale = alpha / dim
+
+ down_weight = down_weight.to(calc_device)
+ up_weight = up_weight.to(calc_device)
+
+ original_dtype = model_weight.dtype
+ if original_dtype.itemsize == 1: # fp8
+ # temporarily convert to float16 for calculation
+ model_weight = model_weight.to(torch.float16)
+ down_weight = down_weight.to(torch.float16)
+ up_weight = up_weight.to(torch.float16)
+
+ # W <- W + U * D
+ if len(model_weight.size()) == 2:
+ # linear
+ if len(up_weight.size()) == 4: # use linear projection mismatch
+ up_weight = up_weight.squeeze(3).squeeze(2)
+ down_weight = down_weight.squeeze(3).squeeze(2)
+ model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ model_weight = (
+ model_weight
+ + multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+ model_weight = model_weight + multiplier * conved * scale
+
+ if original_dtype.itemsize == 1: # fp8
+ model_weight = model_weight.to(original_dtype) # convert back to original dtype
+
+ # remove LoRA keys from set
+ lora_weight_keys.remove(down_key)
+ lora_weight_keys.remove(up_key)
+ if alpha_key in lora_weight_keys:
+ lora_weight_keys.remove(alpha_key)
+ continue
+
+ # Check for LoHa/LoKr weights with same prefix search
+ for prefix in ["lora_unet_", ""]:
+ lora_name = prefix + lora_name_without_prefix.replace(".", "_")
+ hada_key = lora_name + ".hada_w1_a"
+ lokr_key = lora_name + ".lokr_w1"
+
+ if hada_key in lora_weight_keys:
+ # LoHa merge
+ model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
+ break
+ elif lokr_key in lora_weight_keys:
+ # LoKr merge
+ model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
+ break
+
+ if not keep_on_calc_device and original_device != calc_device:
+ model_weight = model_weight.to(original_device) # move back to original device
+ return model_weight
+
+ weight_hook = weight_hook_func
+
+ state_dict = load_safetensors_with_fp8_optimization_and_hook(
+ model_files,
+ fp8_optimization,
+ calc_device,
+ move_to_device,
+ dit_weight_dtype,
+ target_keys,
+ exclude_keys,
+ weight_hook=weight_hook,
+ disable_numpy_memmap=disable_numpy_memmap,
+ weight_transform_hooks=weight_transform_hooks,
+ )
+
+ for lora_weight_keys in list_of_lora_weight_keys:
+ # check if all LoRA keys are used
+ if len(lora_weight_keys) > 0:
+ # if there are still LoRA keys left, it means they are not used in the model
+ # this is a warning, not an error
+ logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
+
+ return state_dict
+
+
+def load_safetensors_with_fp8_optimization_and_hook(
+ model_files: list[str],
+ fp8_optimization: bool,
+ calc_device: torch.device,
+ move_to_device: bool = False,
+ dit_weight_dtype: Optional[torch.dtype] = None,
+ target_keys: Optional[List[str]] = None,
+ exclude_keys: Optional[List[str]] = None,
+ weight_hook: callable = None,
+ disable_numpy_memmap: bool = False,
+ weight_transform_hooks: Optional[WeightTransformHooks] = None,
+) -> dict[str, torch.Tensor]:
+ """
+ Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
+ """
+ if fp8_optimization:
+ logger.info(
+ f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
+ )
+ # dit_weight_dtype is not used because we use fp8 optimization
+ state_dict = load_safetensors_with_fp8_optimization(
+ model_files,
+ calc_device,
+ target_keys,
+ exclude_keys,
+ move_to_device=move_to_device,
+ weight_hook=weight_hook,
+ disable_numpy_memmap=disable_numpy_memmap,
+ weight_transform_hooks=weight_transform_hooks,
+ )
+ else:
+ logger.info(
+ f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
+ )
+ state_dict = {}
+ for model_file in model_files:
+ with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
+ f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
+ for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
+ if weight_hook is None and move_to_device:
+ value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
+ else:
+ value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
+ if weight_hook is not None:
+ value = weight_hook(key, value, keep_on_calc_device=move_to_device)
+ if move_to_device:
+ value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
+ elif dit_weight_dtype is not None:
+ value = value.to(dit_weight_dtype)
+
+ state_dict[key] = value
+ if move_to_device:
+ synchronize_device(calc_device)
+
+ return state_dict
diff --git a/networks/loha.py b/networks/loha.py
new file mode 100644
index 00000000..8734f9c5
--- /dev/null
+++ b/networks/loha.py
@@ -0,0 +1,643 @@
+# LoHa (Low-rank Hadamard Product) network module
+# Reference: https://arxiv.org/abs/2108.06098
+#
+# Based on the LyCORIS project by KohakuBlueleaf
+# https://github.com/KohakuBlueleaf/LyCORIS
+
+import ast
+import os
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
+from library.utils import setup_logging
+
+setup_logging()
+logger = logging.getLogger(__name__)
+
+
+class HadaWeight(torch.autograd.Function):
+ """Efficient Hadamard product forward/backward for LoHa.
+
+ Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward
+ that recomputes intermediates instead of storing them.
+ """
+
+ @staticmethod
+ def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
+ if scale is None:
+ scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
+ ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
+ diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
+ return diff_weight
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
+ grad_out = grad_out * scale
+ temp = grad_out * (w2a @ w2b)
+ grad_w1a = temp @ w1b.T
+ grad_w1b = w1a.T @ temp
+
+ temp = grad_out * (w1a @ w1b)
+ grad_w2a = temp @ w2b.T
+ grad_w2b = w2a.T @ temp
+
+ del temp
+ return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
+
+
+class HadaWeightTucker(torch.autograd.Function):
+ """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.
+
+ Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
+ where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
+ Compatible with LyCORIS parameter naming convention.
+ """
+
+ @staticmethod
+ def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
+ if scale is None:
+ scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
+ ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)
+
+ rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
+ rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
+
+ return rebuild1 * rebuild2 * scale
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
+ grad_out = grad_out * scale
+
+ # Gradients for w1a, w1b, t1 (using rebuild2)
+ temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
+ rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)
+
+ grad_w = rebuild * grad_out
+ del rebuild
+
+ grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+ grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
+ del grad_w, temp
+
+ grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
+ grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
+ del grad_temp
+
+ # Gradients for w2a, w2b, t2 (using rebuild1)
+ temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
+ rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)
+
+ grad_w = rebuild * grad_out
+ del rebuild
+
+ grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+ grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
+ del grad_w, temp
+
+ grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
+ grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
+ del grad_temp
+
+ return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None
+
+
+class LoHaModule(torch.nn.Module):
+ """LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ use_tucker=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.lora_name = lora_name
+ self.lora_dim = lora_dim
+
+ is_conv2d = org_module.__class__.__name__ == "Conv2d"
+ if is_conv2d:
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ kernel_size = org_module.kernel_size
+ self.is_conv = True
+ self.stride = org_module.stride
+ self.padding = org_module.padding
+ self.dilation = org_module.dilation
+ self.groups = org_module.groups
+ self.kernel_size = kernel_size
+
+ self.tucker = use_tucker and any(k != 1 for k in kernel_size)
+
+ if kernel_size == (1, 1):
+ self.conv_mode = "1x1"
+ elif self.tucker:
+ self.conv_mode = "tucker"
+ else:
+ self.conv_mode = "flat"
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+ self.is_conv = False
+ self.tucker = False
+ self.conv_mode = None
+ self.kernel_size = None
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # Create parameters based on mode
+ if self.conv_mode == "tucker":
+ # Tucker decomposition for Conv2d 3x3+
+ # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
+ self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
+ self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
+ self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
+ self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
+
+ # LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
+ torch.nn.init.normal_(self.hada_t1, std=0.1)
+ torch.nn.init.normal_(self.hada_t2, std=0.1)
+ torch.nn.init.normal_(self.hada_w1_b, std=1.0)
+ torch.nn.init.constant_(self.hada_w1_a, 0)
+ torch.nn.init.normal_(self.hada_w2_b, std=1.0)
+ torch.nn.init.normal_(self.hada_w2_a, std=0.1)
+ elif self.conv_mode == "flat":
+ # Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
+ k_prod = 1
+ for k in kernel_size:
+ k_prod *= k
+ flat_in = in_dim * k_prod
+
+ self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
+ self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))
+
+ torch.nn.init.normal_(self.hada_w1_a, std=0.1)
+ torch.nn.init.normal_(self.hada_w1_b, std=1.0)
+ torch.nn.init.constant_(self.hada_w2_a, 0)
+ torch.nn.init.normal_(self.hada_w2_b, std=1.0)
+ else:
+ # Linear or Conv2d 1x1
+ self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
+ self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
+ self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
+ self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
+
+ torch.nn.init.normal_(self.hada_w1_a, std=0.1)
+ torch.nn.init.normal_(self.hada_w1_b, std=1.0)
+ torch.nn.init.constant_(self.hada_w2_a, 0)
+ torch.nn.init.normal_(self.hada_w2_b, std=1.0)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy()
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha))
+
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def get_diff_weight(self):
+ """Return materialized weight delta.
+
+ Returns:
+ - Linear: 2D tensor (out_dim, in_dim)
+ - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
+ - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
+ - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
+ """
+ if self.tucker:
+ scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
+ return HadaWeightTucker.apply(
+ self.hada_t1, self.hada_w1_b, self.hada_w1_a,
+ self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
+ )
+ elif self.conv_mode == "flat":
+ scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
+ diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+ return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
+ else:
+ scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
+ return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ diff_weight = self.get_diff_weight()
+
+ # rank dropout (applied on output dimension)
+ if self.rank_dropout is not None and self.training:
+ drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
+ drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
+ diff_weight = diff_weight * drop
+ scale = 1.0 / (1.0 - self.rank_dropout)
+ else:
+ scale = 1.0
+
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
+ return org_forwarded + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier * scale
+ else:
+ # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
+ return org_forwarded + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier * scale
+ else:
+ return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class LoHaInfModule(LoHaModule):
+ """LoHa module for inference. Supports merge_to and get_weight."""
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference; pass use_tucker from kwargs
+ use_tucker = kwargs.pop("use_tucker", False)
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)
+
+ self.org_module_ref = [org_module]
+ self.enabled = True
+ self.network: AdditionalNetwork = None
+
+ def set_network(self, network):
+ self.network = network
+
+ def merge_to(self, sd, dtype, device):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(torch.float)
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ # get LoHa weights
+ w1a = sd["hada_w1_a"].to(torch.float).to(device)
+ w1b = sd["hada_w1_b"].to(torch.float).to(device)
+ w2a = sd["hada_w2_a"].to(torch.float).to(device)
+ w2b = sd["hada_w2_b"].to(torch.float).to(device)
+
+ if self.tucker:
+ # Tucker mode
+ t1 = sd["hada_t1"].to(torch.float).to(device)
+ t2 = sd["hada_t2"].to(torch.float).to(device)
+ rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
+ rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
+ diff_weight = rebuild1 * rebuild2 * self.scale
+ else:
+ diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
+ # reshape diff_weight to match original weight shape if needed
+ if diff_weight.shape != weight.shape:
+ diff_weight = diff_weight.reshape(weight.shape)
+
+ weight = weight.to(device) + self.multiplier * diff_weight
+
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ if self.tucker:
+ t1 = self.hada_t1.to(torch.float)
+ w1a = self.hada_w1_a.to(torch.float)
+ w1b = self.hada_w1_b.to(torch.float)
+ t2 = self.hada_t2.to(torch.float)
+ w2a = self.hada_w2_a.to(torch.float)
+ w2b = self.hada_w2_b.to(torch.float)
+ rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
+ rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
+ weight = rebuild1 * rebuild2 * self.scale * multiplier
+ else:
+ w1a = self.hada_w1_a.to(torch.float)
+ w1b = self.hada_w1_b.to(torch.float)
+ w2a = self.hada_w2_a.to(torch.float)
+ w2b = self.hada_w2_b.to(torch.float)
+ weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
+
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ weight = weight.unsqueeze(2).unsqueeze(3)
+ elif self.conv_mode == "flat":
+ weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
+
+ return weight
+
+ def default_forward(self, x):
+ diff_weight = self.get_diff_weight()
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
+ return self.org_forward(x) + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier
+ else:
+ return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
+def create_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae,
+ text_encoder,
+ unet,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ """Create a LoHa network. Called by train_network.py via network_module.create_network()."""
+ if network_dim is None:
+ network_dim = 4
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ # handle text_encoder as list
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+ # detect architecture
+ arch_config = detect_arch_config(unet, text_encoders)
+
+ # train LLM adapter
+ train_llm_adapter = kwargs.get("train_llm_adapter", "false")
+ if train_llm_adapter is not None:
+ train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
+
+ # exclude patterns
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+ if not isinstance(exclude_patterns, list):
+ exclude_patterns = [exclude_patterns]
+
+ # add default exclude patterns from arch config
+ exclude_patterns.extend(arch_config.default_excludes)
+
+ # include patterns
+ include_patterns = kwargs.get("include_patterns", None)
+ if include_patterns is not None:
+ include_patterns = ast.literal_eval(include_patterns)
+ if not isinstance(include_patterns, list):
+ include_patterns = [include_patterns]
+
+ # rank/module dropout
+ rank_dropout = kwargs.get("rank_dropout", None)
+ if rank_dropout is not None:
+ rank_dropout = float(rank_dropout)
+ module_dropout = kwargs.get("module_dropout", None)
+ if module_dropout is not None:
+ module_dropout = float(module_dropout)
+
+ # conv dim/alpha for Conv2d 3x3
+ conv_lora_dim = kwargs.get("conv_dim", None)
+ conv_alpha = kwargs.get("conv_alpha", None)
+ if conv_lora_dim is not None:
+ conv_lora_dim = int(conv_lora_dim)
+ if conv_alpha is None:
+ conv_alpha = 1.0
+ else:
+ conv_alpha = float(conv_alpha)
+
+ # Tucker decomposition for Conv2d 3x3
+ use_tucker = kwargs.get("use_tucker", "false")
+ if use_tucker is not None:
+ use_tucker = True if str(use_tucker).lower() == "true" else False
+
+ # verbose
+ verbose = kwargs.get("verbose", "false")
+ if verbose is not None:
+ verbose = True if str(verbose).lower() == "true" else False
+
+ # regex-specific learning rates / dimensions
+ network_reg_lrs = kwargs.get("network_reg_lrs", None)
+ reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
+
+ network_reg_dims = kwargs.get("network_reg_dims", None)
+ reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
+
+ network = AdditionalNetwork(
+ text_encoders,
+ unet,
+ arch_config=arch_config,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ module_class=LoHaModule,
+ module_kwargs={"use_tucker": use_tucker},
+ conv_lora_dim=conv_lora_dim,
+ conv_alpha=conv_alpha,
+ train_llm_adapter=train_llm_adapter,
+ exclude_patterns=exclude_patterns,
+ include_patterns=include_patterns,
+ reg_dims=reg_dims,
+ reg_lrs=reg_lrs,
+ verbose=verbose,
+ )
+
+ # LoRA+ support
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+ loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+ loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+ loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+ loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+ if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+ return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+ """Create a LoHa network from saved weights. Called by train_network.py."""
+ if weights_sd is None:
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ # detect dim/alpha from weights
+ modules_dim = {}
+ modules_alpha = {}
+ train_llm_adapter = False
+ for key, value in weights_sd.items():
+ if "." not in key:
+ continue
+
+ lora_name = key.split(".")[0]
+ if "alpha" in key:
+ modules_alpha[lora_name] = value
+ elif "hada_w1_b" in key:
+ dim = value.shape[0]
+ modules_dim[lora_name] = dim
+
+ if "llm_adapter" in lora_name:
+ train_llm_adapter = True
+
+ # detect Tucker mode from weights
+ use_tucker = any("hada_t1" in key for key in weights_sd.keys())
+
+ # handle text_encoder as list
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+ # detect architecture
+ arch_config = detect_arch_config(unet, text_encoders)
+
+ module_class = LoHaInfModule if for_inference else LoHaModule
+ module_kwargs = {"use_tucker": use_tucker}
+
+ network = AdditionalNetwork(
+ text_encoders,
+ unet,
+ arch_config=arch_config,
+ multiplier=multiplier,
+ modules_dim=modules_dim,
+ modules_alpha=modules_alpha,
+ module_class=module_class,
+ module_kwargs=module_kwargs,
+ train_llm_adapter=train_llm_adapter,
+ )
+ return network, weights_sd
+
+
+def merge_weights_to_tensor(
+ model_weight: torch.Tensor,
+ lora_name: str,
+ lora_sd: Dict[str, torch.Tensor],
+ lora_weight_keys: set,
+ multiplier: float,
+ calc_device: torch.device,
+) -> torch.Tensor:
+ """Merge LoHa weights directly into a model weight tensor.
+
+ Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
+ No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
+ Returns model_weight unchanged if no matching LoHa keys found.
+ """
+ w1a_key = lora_name + ".hada_w1_a"
+ w1b_key = lora_name + ".hada_w1_b"
+ w2a_key = lora_name + ".hada_w2_a"
+ w2b_key = lora_name + ".hada_w2_b"
+ t1_key = lora_name + ".hada_t1"
+ t2_key = lora_name + ".hada_t2"
+ alpha_key = lora_name + ".alpha"
+
+ if w1a_key not in lora_weight_keys:
+ return model_weight
+
+ w1a = lora_sd[w1a_key].to(calc_device)
+ w1b = lora_sd[w1b_key].to(calc_device)
+ w2a = lora_sd[w2a_key].to(calc_device)
+ w2b = lora_sd[w2b_key].to(calc_device)
+
+ has_tucker = t1_key in lora_weight_keys
+
+ dim = w1b.shape[0]
+ alpha = lora_sd.get(alpha_key, torch.tensor(dim))
+ if isinstance(alpha, torch.Tensor):
+ alpha = alpha.item()
+ scale = alpha / dim
+
+ original_dtype = model_weight.dtype
+ if original_dtype.itemsize == 1: # fp8
+ model_weight = model_weight.to(torch.float16)
+ w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
+ w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
+
+ if has_tucker:
+ # Tucker decomposition: rebuild via einsum
+ t1 = lora_sd[t1_key].to(calc_device)
+ t2 = lora_sd[t2_key].to(calc_device)
+ if original_dtype.itemsize == 1:
+ t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
+ rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
+ rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
+ diff_weight = rebuild1 * rebuild2 * scale
+ else:
+ # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
+ diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
+
+ # Reshape diff_weight to match model_weight shape if needed
+ # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
+ if diff_weight.shape != model_weight.shape:
+ diff_weight = diff_weight.reshape(model_weight.shape)
+
+ model_weight = model_weight + multiplier * diff_weight
+
+ if original_dtype.itemsize == 1:
+ model_weight = model_weight.to(original_dtype)
+
+ # remove consumed keys
+ consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
+ if has_tucker:
+ consumed.extend([t1_key, t2_key])
+ for key in consumed:
+ lora_weight_keys.discard(key)
+
+ return model_weight
diff --git a/networks/lokr.py b/networks/lokr.py
new file mode 100644
index 00000000..03b50ca0
--- /dev/null
+++ b/networks/lokr.py
@@ -0,0 +1,683 @@
+# LoKr (Low-rank Kronecker Product) network module
+# Reference: https://arxiv.org/abs/2309.14859
+#
+# Based on the LyCORIS project by KohakuBlueleaf
+# https://github.com/KohakuBlueleaf/LyCORIS
+
+import ast
+import math
+import os
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
+from library.utils import setup_logging
+
+setup_logging()
+logger = logging.getLogger(__name__)
+
+
+def factorization(dimension: int, factor: int = -1) -> tuple:
+ """Return a tuple of two values whose product equals dimension,
+ optimized for balanced factors.
+
+ In LoKr, the first value is for the weight scale (smaller),
+ and the second value is for the weight (larger).
+
+ Examples:
+ factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
+ factor=4: 128 -> (4, 32), 512 -> (4, 128)
+ """
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m < n:
+ new_m = m + 1
+ while dimension % new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m > factor:
+ break
+ else:
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
+
+
+def make_kron(w1, w2, scale):
+ """Compute Kronecker product of w1 and w2, scaled by scale."""
+ if w1.dim() != w2.dim():
+ for _ in range(w2.dim() - w1.dim()):
+ w1 = w1.unsqueeze(-1)
+ w2 = w2.contiguous()
+ rebuild = torch.kron(w1, w2)
+ if scale != 1:
+ rebuild = rebuild * scale
+ return rebuild
+
+
+def rebuild_tucker(t, wa, wb):
+ """Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb).
+
+ Compatible with LyCORIS convention.
+ """
+ return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)
+
+
+class LoKrModule(torch.nn.Module):
+ """LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ factor=-1,
+ use_tucker=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.lora_name = lora_name
+ self.lora_dim = lora_dim
+
+ is_conv2d = org_module.__class__.__name__ == "Conv2d"
+ if is_conv2d:
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ kernel_size = org_module.kernel_size
+ self.is_conv = True
+ self.stride = org_module.stride
+ self.padding = org_module.padding
+ self.dilation = org_module.dilation
+ self.groups = org_module.groups
+ self.kernel_size = kernel_size
+
+ self.tucker = use_tucker and any(k != 1 for k in kernel_size)
+
+ if kernel_size == (1, 1):
+ self.conv_mode = "1x1"
+ elif self.tucker:
+ self.conv_mode = "tucker"
+ else:
+ self.conv_mode = "flat"
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+ self.is_conv = False
+ self.tucker = False
+ self.conv_mode = None
+ self.kernel_size = None
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ factor = int(factor)
+ self.use_w2 = False
+
+ # Factorize dimensions
+ in_m, in_n = factorization(in_dim, factor)
+ out_l, out_k = factorization(out_dim, factor)
+
+ # w1 is always a full matrix (the "scale" factor, small)
+ self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
+
+ # w2: depends on mode
+ if self.conv_mode in ("tucker", "flat"):
+ # Conv2d 3x3+ modes
+ k_size = kernel_size
+
+ if lora_dim >= max(out_k, in_n) / 2:
+ # Full matrix mode (includes kernel dimensions)
+ self.use_w2 = True
+ self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
+ logger.warning(
+ f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
+ f"and factor={factor}, using full matrix mode for Conv2d."
+ )
+ elif self.tucker:
+ # Tucker mode: separate kernel into t2 tensor
+ self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
+ self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
+ else:
+ # Non-Tucker: flatten kernel into w2_b
+ k_prod = 1
+ for k in k_size:
+ k_prod *= k
+ self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
+ else:
+ # Linear or Conv2d 1x1
+ if lora_dim < max(out_k, in_n) / 2:
+ self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
+ self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
+ else:
+ self.use_w2 = True
+ self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
+ if lora_dim >= max(out_k, in_n) / 2:
+ logger.warning(
+ f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
+ f"and factor={factor}, using full matrix mode."
+ )
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy()
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
+ # if both w1 and w2 are full matrices, use scale = 1
+ if self.use_w2:
+ alpha = lora_dim
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha))
+
+ # Initialization
+ torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
+ if self.use_w2:
+ torch.nn.init.constant_(self.lokr_w2, 0)
+ else:
+ if self.tucker:
+ torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
+ torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
+ torch.nn.init.constant_(self.lokr_w2_b, 0)
+ # Ensures ΔW = kron(w1, 0) = 0 at init
+
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def get_diff_weight(self):
+ """Return materialized weight delta.
+
+ Returns:
+ - Linear: 2D tensor (out_dim, in_dim)
+ - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
+ - Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
+ - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
+ """
+ w1 = self.lokr_w1
+
+ if self.use_w2:
+ w2 = self.lokr_w2
+ elif self.tucker:
+ w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
+ else:
+ w2 = self.lokr_w2_a @ self.lokr_w2_b
+
+ result = make_kron(w1, w2, self.scale)
+
+ # For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D
+ if self.conv_mode == "flat" and result.dim() == 2:
+ result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)
+
+ return result
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ diff_weight = self.get_diff_weight()
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
+ drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
+ diff_weight = diff_weight * drop
+ scale = 1.0 / (1.0 - self.rank_dropout)
+ else:
+ scale = 1.0
+
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
+ return org_forwarded + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier * scale
+ else:
+ # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
+ return org_forwarded + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier * scale
+ else:
+ return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class LoKrInfModule(LoKrModule):
+ """LoKr module for inference. Supports merge_to and get_weight."""
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference; pass factor and use_tucker from kwargs
+ factor = kwargs.pop("factor", -1)
+ use_tucker = kwargs.pop("use_tucker", False)
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)
+
+ self.org_module_ref = [org_module]
+ self.enabled = True
+ self.network: AdditionalNetwork = None
+
+ def set_network(self, network):
+ self.network = network
+
+ def merge_to(self, sd, dtype, device):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(torch.float)
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ # get LoKr weights
+ w1 = sd["lokr_w1"].to(torch.float).to(device)
+
+ if "lokr_w2" in sd:
+ w2 = sd["lokr_w2"].to(torch.float).to(device)
+ elif "lokr_t2" in sd:
+ # Tucker mode
+ t2 = sd["lokr_t2"].to(torch.float).to(device)
+ w2a = sd["lokr_w2_a"].to(torch.float).to(device)
+ w2b = sd["lokr_w2_b"].to(torch.float).to(device)
+ w2 = rebuild_tucker(t2, w2a, w2b)
+ else:
+ w2a = sd["lokr_w2_a"].to(torch.float).to(device)
+ w2b = sd["lokr_w2_b"].to(torch.float).to(device)
+ w2 = w2a @ w2b
+
+ # compute ΔW via Kronecker product
+ diff_weight = make_kron(w1, w2, self.scale)
+
+ # reshape diff_weight to match original weight shape if needed
+ if diff_weight.shape != weight.shape:
+ diff_weight = diff_weight.reshape(weight.shape)
+
+ weight = weight.to(device) + self.multiplier * diff_weight
+
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ w1 = self.lokr_w1.to(torch.float)
+
+ if self.use_w2:
+ w2 = self.lokr_w2.to(torch.float)
+ elif self.tucker:
+ w2 = rebuild_tucker(
+ self.lokr_t2.to(torch.float),
+ self.lokr_w2_a.to(torch.float),
+ self.lokr_w2_b.to(torch.float),
+ )
+ else:
+ w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
+
+ weight = make_kron(w1, w2, self.scale) * multiplier
+
+ # reshape to match original weight shape if needed
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ weight = weight.unsqueeze(2).unsqueeze(3)
+ elif self.conv_mode == "flat" and weight.dim() == 2:
+ weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
+ # Tucker and full matrix modes: already 4D from kron
+
+ return weight
+
+ def default_forward(self, x):
+ diff_weight = self.get_diff_weight()
+ if self.is_conv:
+ if self.conv_mode == "1x1":
+ diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
+ return self.org_forward(x) + F.conv2d(
+ x, diff_weight, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups
+ ) * self.multiplier
+ else:
+ return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
+def create_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ vae,
+ text_encoder,
+ unet,
+ neuron_dropout: Optional[float] = None,
+ **kwargs,
+):
+ """Create a LoKr network. Called by train_network.py via network_module.create_network()."""
+ if network_dim is None:
+ network_dim = 4
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ # handle text_encoder as list
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+ # detect architecture
+ arch_config = detect_arch_config(unet, text_encoders)
+
+ # train LLM adapter
+ train_llm_adapter = kwargs.get("train_llm_adapter", "false")
+ if train_llm_adapter is not None:
+ train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False
+
+ # exclude patterns
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+ if not isinstance(exclude_patterns, list):
+ exclude_patterns = [exclude_patterns]
+
+ # add default exclude patterns from arch config
+ exclude_patterns.extend(arch_config.default_excludes)
+
+ # include patterns
+ include_patterns = kwargs.get("include_patterns", None)
+ if include_patterns is not None:
+ include_patterns = ast.literal_eval(include_patterns)
+ if not isinstance(include_patterns, list):
+ include_patterns = [include_patterns]
+
+ # rank/module dropout
+ rank_dropout = kwargs.get("rank_dropout", None)
+ if rank_dropout is not None:
+ rank_dropout = float(rank_dropout)
+ module_dropout = kwargs.get("module_dropout", None)
+ if module_dropout is not None:
+ module_dropout = float(module_dropout)
+
+ # conv dim/alpha for Conv2d 3x3
+ conv_lora_dim = kwargs.get("conv_dim", None)
+ conv_alpha = kwargs.get("conv_alpha", None)
+ if conv_lora_dim is not None:
+ conv_lora_dim = int(conv_lora_dim)
+ if conv_alpha is None:
+ conv_alpha = 1.0
+ else:
+ conv_alpha = float(conv_alpha)
+
+ # Tucker decomposition for Conv2d 3x3
+ use_tucker = kwargs.get("use_tucker", "false")
+ if use_tucker is not None:
+ use_tucker = True if str(use_tucker).lower() == "true" else False
+
+ # factor for LoKr
+ factor = int(kwargs.get("factor", -1))
+
+ # verbose
+ verbose = kwargs.get("verbose", "false")
+ if verbose is not None:
+ verbose = True if str(verbose).lower() == "true" else False
+
+ # regex-specific learning rates / dimensions
+ network_reg_lrs = kwargs.get("network_reg_lrs", None)
+ reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
+
+ network_reg_dims = kwargs.get("network_reg_dims", None)
+ reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
+
+ network = AdditionalNetwork(
+ text_encoders,
+ unet,
+ arch_config=arch_config,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ module_class=LoKrModule,
+ module_kwargs={"factor": factor, "use_tucker": use_tucker},
+ conv_lora_dim=conv_lora_dim,
+ conv_alpha=conv_alpha,
+ train_llm_adapter=train_llm_adapter,
+ exclude_patterns=exclude_patterns,
+ include_patterns=include_patterns,
+ reg_dims=reg_dims,
+ reg_lrs=reg_lrs,
+ verbose=verbose,
+ )
+
+ # LoRA+ support
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+ loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+ loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+ loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+ loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+ if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+ return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+ """Create a LoKr network from saved weights. Called by train_network.py."""
+ if weights_sd is None:
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ # detect dim/alpha from weights
+ modules_dim = {}
+ modules_alpha = {}
+ train_llm_adapter = False
+ use_tucker = False
+ for key, value in weights_sd.items():
+ if "." not in key:
+ continue
+
+ lora_name = key.split(".")[0]
+ if "alpha" in key:
+ modules_alpha[lora_name] = value
+ elif "lokr_w2_a" in key:
+ # low-rank mode: dim detection depends on Tucker vs non-Tucker
+ if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
+ # Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
+ dim = value.shape[0]
+ else:
+ # Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
+ dim = value.shape[1]
+ modules_dim[lora_name] = dim
+ elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
+ # full matrix mode: set dim large enough to trigger full-matrix path
+ if lora_name not in modules_dim:
+ modules_dim[lora_name] = max(value.shape[0], value.shape[1])
+
+ if "lokr_t2" in key:
+ use_tucker = True
+
+ if "llm_adapter" in lora_name:
+ train_llm_adapter = True
+
+ # handle text_encoder as list
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+ # detect architecture
+ arch_config = detect_arch_config(unet, text_encoders)
+
+ # extract factor for LoKr
+ factor = int(kwargs.get("factor", -1))
+
+ module_class = LoKrInfModule if for_inference else LoKrModule
+ module_kwargs = {"factor": factor, "use_tucker": use_tucker}
+
+ network = AdditionalNetwork(
+ text_encoders,
+ unet,
+ arch_config=arch_config,
+ multiplier=multiplier,
+ modules_dim=modules_dim,
+ modules_alpha=modules_alpha,
+ module_class=module_class,
+ module_kwargs=module_kwargs,
+ train_llm_adapter=train_llm_adapter,
+ )
+ return network, weights_sd
+
+
+def merge_weights_to_tensor(
+ model_weight: torch.Tensor,
+ lora_name: str,
+ lora_sd: Dict[str, torch.Tensor],
+ lora_weight_keys: set,
+ multiplier: float,
+ calc_device: torch.device,
+) -> torch.Tensor:
+ """Merge LoKr weights directly into a model weight tensor.
+
+ Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
+ No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
+ Returns model_weight unchanged if no matching LoKr keys found.
+ """
+ w1_key = lora_name + ".lokr_w1"
+ w2_key = lora_name + ".lokr_w2"
+ w2a_key = lora_name + ".lokr_w2_a"
+ w2b_key = lora_name + ".lokr_w2_b"
+ t2_key = lora_name + ".lokr_t2"
+ alpha_key = lora_name + ".alpha"
+
+ if w1_key not in lora_weight_keys:
+ return model_weight
+
+ w1 = lora_sd[w1_key].to(calc_device)
+
+ # determine mode: full matrix vs Tucker vs low-rank
+ has_tucker = t2_key in lora_weight_keys
+
+ if w2a_key in lora_weight_keys:
+ w2a = lora_sd[w2a_key].to(calc_device)
+ w2b = lora_sd[w2b_key].to(calc_device)
+
+ if has_tucker:
+ # Tucker: w2a = (rank, out_k), dim = rank
+ dim = w2a.shape[0]
+ else:
+ # Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
+ dim = w2a.shape[1]
+
+ consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
+ if has_tucker:
+ consumed_keys.append(t2_key)
+ elif w2_key in lora_weight_keys:
+ # full matrix mode
+ w2a = None
+ w2b = None
+ dim = None
+ consumed_keys = [w1_key, w2_key, alpha_key]
+ else:
+ return model_weight
+
+ alpha = lora_sd.get(alpha_key, None)
+ if alpha is not None and isinstance(alpha, torch.Tensor):
+ alpha = alpha.item()
+
+ # compute scale
+ if w2a is not None:
+ if alpha is None:
+ alpha = dim
+ scale = alpha / dim
+ else:
+ # full matrix mode: scale = 1.0
+ scale = 1.0
+
+ original_dtype = model_weight.dtype
+ if original_dtype.itemsize == 1: # fp8
+ model_weight = model_weight.to(torch.float16)
+ w1 = w1.to(torch.float16)
+ if w2a is not None:
+ w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
+
+ # compute w2
+ if w2a is not None:
+ if has_tucker:
+ t2 = lora_sd[t2_key].to(calc_device)
+ if original_dtype.itemsize == 1:
+ t2 = t2.to(torch.float16)
+ w2 = rebuild_tucker(t2, w2a, w2b)
+ else:
+ w2 = w2a @ w2b
+ else:
+ w2 = lora_sd[w2_key].to(calc_device)
+ if original_dtype.itemsize == 1:
+ w2 = w2.to(torch.float16)
+
+ # ΔW = kron(w1, w2) * scale
+ diff_weight = make_kron(w1, w2, scale)
+
+ # Reshape diff_weight to match model_weight shape if needed
+ # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
+ if diff_weight.shape != model_weight.shape:
+ diff_weight = diff_weight.reshape(model_weight.shape)
+
+ model_weight = model_weight + multiplier * diff_weight
+
+ if original_dtype.itemsize == 1:
+ model_weight = model_weight.to(original_dtype)
+
+ # remove consumed keys
+ for key in consumed_keys:
+ lora_weight_keys.discard(key)
+
+ return model_weight
diff --git a/networks/lora_anima.py b/networks/lora_anima.py
index 9413e8c8..4cff2819 100644
--- a/networks/lora_anima.py
+++ b/networks/lora_anima.py
@@ -1,11 +1,11 @@
# LoRA network module for Anima
import ast
+import math
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
-from networks.lora_flux import LoRAModule, LoRAInfModule
import logging
@@ -13,6 +13,213 @@ setup_logging()
logger = logging.getLogger(__name__)
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ ):
+ """
+ if alpha == 0 or None, alpha is rank (no scaling).
+ """
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ self.lora_dim = lora_dim
+
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
+
+ # same as microsoft's
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+
+ del self.org_module
+
+ def forward(self, x):
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ lx = self.lora_down(x)
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if isinstance(self.lora_down, torch.nn.Conv2d):
+ # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1]
+ mask = mask.unsqueeze(-1).unsqueeze(-1)
+ else:
+ # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim]
+ for _ in range(len(lx.size()) - 2):
+ mask = mask.unsqueeze(1)
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded + lx * self.multiplier * scale
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+
+class LoRAInfModule(LoRAModule):
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ **kwargs,
+ ):
+ # no dropout for inference
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+
+ self.org_module_ref = [org_module] # 後から参照できるように
+ self.enabled = True
+ self.network: LoRANetwork = None
+
+ def set_network(self, network):
+ self.network = network
+
+ # freezeしてマージする
+ def merge_to(self, sd, dtype, device):
+ # extract weight from org_module
+ org_sd = self.org_module.state_dict()
+ weight = org_sd["weight"]
+ org_dtype = weight.dtype
+ org_device = weight.device
+ weight = weight.to(torch.float) # calc in float
+
+ if dtype is None:
+ dtype = org_dtype
+ if device is None:
+ device = org_device
+
+ # get up/down weight
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
+
+ # merge weight
+ if len(weight.size()) == 2:
+ # linear
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ weight
+ + self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+ weight = weight + self.multiplier * conved * self.scale
+
+ # set weight to org_module
+ org_sd["weight"] = weight.to(dtype)
+ self.org_module.load_state_dict(org_sd)
+
+ # 復元できるマージのため、このモジュールのweightを返す
+ def get_weight(self, multiplier=None):
+ if multiplier is None:
+ multiplier = self.multiplier
+
+ # get up/down weight from module
+ up_weight = self.lora_up.weight.to(torch.float)
+ down_weight = self.lora_down.weight.to(torch.float)
+
+ # pre-calculated weight
+ if len(down_weight.size()) == 2:
+ # linear
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ self.multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * self.scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ weight = self.multiplier * conved * self.scale
+
+ return weight
+
+ def default_forward(self, x):
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
+ lx = self.lora_down(x)
+ lx = self.lora_up(lx)
+ return self.org_forward(x) + lx * self.multiplier * self.scale
+
+ def forward(self, x):
+ if not self.enabled:
+ return self.org_forward(x)
+ return self.default_forward(x)
+
+
def create_network(
multiplier: float,
network_dim: Optional[int],
diff --git a/networks/network_base.py b/networks/network_base.py
new file mode 100644
index 00000000..d9697562
--- /dev/null
+++ b/networks/network_base.py
@@ -0,0 +1,545 @@
+# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
+# Provides architecture detection and a generic AdditionalNetwork class.
+
+import os
+import re
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import torch
+from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ArchConfig:
+ unet_target_modules: List[str]
+ te_target_modules: List[str]
+ unet_prefix: str
+ te_prefixes: List[str]
+ default_excludes: List[str] = field(default_factory=list)
+ adapter_target_modules: List[str] = field(default_factory=list)
+ unet_conv_target_modules: List[str] = field(default_factory=list)
+
+
+def detect_arch_config(unet, text_encoders) -> ArchConfig:
+ """Detect architecture from model structure and return ArchConfig."""
+ from library.sdxl_original_unet import SdxlUNet2DConditionModel
+
+ # Check SDXL first
+ if unet is not None and (
+ issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel)
+ ):
+ return ArchConfig(
+ unet_target_modules=["Transformer2DModel"],
+ te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
+ unet_prefix="lora_unet",
+ te_prefixes=["lora_te1", "lora_te2"],
+ default_excludes=[],
+ unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
+ )
+
+ # Check Anima: look for Block class in named_modules
+ module_class_names = set()
+ if unet is not None:
+ for module in unet.modules():
+ module_class_names.add(type(module).__name__)
+
+ if "Block" in module_class_names:
+ return ArchConfig(
+ unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
+ te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
+ unet_prefix="lora_unet",
+ te_prefixes=["lora_te"],
+ default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
+ adapter_target_modules=["LLMAdapterTransformerBlock"],
+ )
+
+ raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")
+
+
+def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
+ """Parse a string of key-value pairs separated by commas."""
+ pairs = {}
+ for pair in kv_pair_str.split(","):
+ pair = pair.strip()
+ if not pair:
+ continue
+ if "=" not in pair:
+ logger.warning(f"Invalid format: {pair}, expected 'key=value'")
+ continue
+ key, value = pair.split("=", 1)
+ key = key.strip()
+ value = value.strip()
+ try:
+ pairs[key] = int(value) if is_int else float(value)
+ except ValueError:
+ logger.warning(f"Invalid value for {key}: {value}")
+ return pairs
+
+
+class AdditionalNetwork(torch.nn.Module):
+ """Generic Additional network that supports LoHa, LoKr, and similar module types.
+
+ Constructed with a module_class parameter to inject the specific module type.
+ Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
+ """
+
+ def __init__(
+ self,
+ text_encoders: list,
+ unet,
+ arch_config: ArchConfig,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ module_class: Type[torch.nn.Module] = None,
+ module_kwargs: Optional[Dict] = None,
+ modules_dim: Optional[Dict[str, int]] = None,
+ modules_alpha: Optional[Dict[str, int]] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ exclude_patterns: Optional[List[str]] = None,
+ include_patterns: Optional[List[str]] = None,
+ reg_dims: Optional[Dict[str, int]] = None,
+ reg_lrs: Optional[Dict[str, float]] = None,
+ train_llm_adapter: bool = False,
+ verbose: bool = False,
+ ) -> None:
+ super().__init__()
+ assert module_class is not None, "module_class must be specified"
+
+ self.multiplier = multiplier
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.conv_lora_dim = conv_lora_dim
+ self.conv_alpha = conv_alpha
+ self.train_llm_adapter = train_llm_adapter
+ self.reg_dims = reg_dims
+ self.reg_lrs = reg_lrs
+ self.arch_config = arch_config
+
+ self.loraplus_lr_ratio = None
+ self.loraplus_unet_lr_ratio = None
+ self.loraplus_text_encoder_lr_ratio = None
+
+ if module_kwargs is None:
+ module_kwargs = {}
+
+ if modules_dim is not None:
+ logger.info(f"create {module_class.__name__} network from weights")
+ else:
+ logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ logger.info(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+ )
+
+ # compile regular expressions
+ def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
+ re_patterns = []
+ if patterns is not None:
+ for pattern in patterns:
+ try:
+ re_pattern = re.compile(pattern)
+ except re.error as e:
+ logger.error(f"Invalid pattern '{pattern}': {e}")
+ continue
+ re_patterns.append(re_pattern)
+ return re_patterns
+
+ exclude_re_patterns = str_to_re_patterns(exclude_patterns)
+ include_re_patterns = str_to_re_patterns(include_patterns)
+
+ # create module instances
+ def create_modules(
+ prefix: str,
+ root_module: torch.nn.Module,
+ target_replace_modules: List[str],
+ default_dim: Optional[int] = None,
+ ) -> Tuple[List[torch.nn.Module], List[str]]:
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
+ if target_replace_modules is None:
+ module = root_module
+
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear"
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+ if is_linear or is_conv2d:
+ original_name = (name + "." if name else "") + child_name
+ lora_name = f"{prefix}.{original_name}".replace(".", "_")
+
+ # exclude/include filter
+ excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
+ included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
+ if excluded and not included:
+ if verbose:
+ logger.info(f"exclude: {original_name}")
+ continue
+
+ dim = None
+ alpha_val = None
+
+ if modules_dim is not None:
+ if lora_name in modules_dim:
+ dim = modules_dim[lora_name]
+ alpha_val = modules_alpha[lora_name]
+ else:
+ if self.reg_dims is not None:
+ for reg, d in self.reg_dims.items():
+ if re.fullmatch(reg, original_name):
+ dim = d
+ alpha_val = self.alpha
+ logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
+ break
+ # fallback to default dim
+ if dim is None:
+ if is_linear or is_conv2d_1x1:
+ dim = default_dim if default_dim is not None else self.lora_dim
+ alpha_val = self.alpha
+ elif is_conv2d and self.conv_lora_dim is not None:
+ dim = self.conv_lora_dim
+ alpha_val = self.conv_alpha
+
+ if dim is None or dim == 0:
+ if is_linear or is_conv2d_1x1:
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha_val,
+ dropout=dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ **module_kwargs,
+ )
+ lora.original_name = original_name
+ loras.append(lora)
+
+ if target_replace_modules is None:
+ break
+ return loras, skipped
+
+ # Create modules for text encoders
+ self.text_encoder_loras: List[torch.nn.Module] = []
+ skipped_te = []
+ if text_encoders is not None:
+ for i, text_encoder in enumerate(text_encoders):
+ if text_encoder is None:
+ continue
+
+ # Determine prefix for this text encoder
+ if i < len(arch_config.te_prefixes):
+ te_prefix = arch_config.te_prefixes[i]
+ else:
+ te_prefix = arch_config.te_prefixes[0]
+
+ logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
+ te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
+ logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
+ self.text_encoder_loras.extend(te_loras)
+ skipped_te += te_skipped
+
+ # Create modules for UNet/DiT
+ target_modules = list(arch_config.unet_target_modules)
+ if modules_dim is not None or conv_lora_dim is not None:
+ target_modules.extend(arch_config.unet_conv_target_modules)
+ if train_llm_adapter and arch_config.adapter_target_modules:
+ target_modules.extend(arch_config.adapter_target_modules)
+
+ self.unet_loras: List[torch.nn.Module]
+ self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
+ logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")
+
+ if verbose:
+ for lora in self.unet_loras:
+ logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
+
+ skipped = skipped_te + skipped_un
+ if verbose and len(skipped) > 0:
+ logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
+ for name in skipped:
+ logger.info(f"\t{name}")
+
+ # assertion: no duplicate names
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def set_enabled(self, is_enabled):
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.enabled = is_enabled
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
+ if apply_text_encoder:
+ logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ def is_mergeable(self):
+ return True
+
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
+ apply_text_encoder = apply_unet = False
+ te_prefixes = self.arch_config.te_prefixes
+ unet_prefix = self.arch_config.unet_prefix
+
+ for key in weights_sd.keys():
+ if any(key.startswith(p) for p in te_prefixes):
+ apply_text_encoder = True
+ elif key.startswith(unet_prefix):
+ apply_unet = True
+
+ if apply_text_encoder:
+ logger.info("enable modules for text encoder")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ logger.info("enable modules for UNet/DiT")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ sd_for_lora = {}
+ for key in weights_sd.keys():
+ if key.startswith(lora.lora_name):
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+ lora.merge_to(sd_for_lora, dtype, device)
+
+ logger.info("weights are merged")
+
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+ self.loraplus_lr_ratio = loraplus_lr_ratio
+ self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+ self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
+ logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
+
+ def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
+ if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
+ text_encoder_lr = [default_lr]
+ elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
+ text_encoder_lr = [float(text_encoder_lr)]
+ elif len(text_encoder_lr) == 1:
+ pass # already a list with one element
+
+ self.requires_grad_(True)
+
+ all_params = []
+ lr_descriptions = []
+
+ def assemble_params(loras, lr, loraplus_ratio):
+ param_groups = {"lora": {}, "plus": {}}
+ reg_groups = {}
+ reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
+
+ for lora in loras:
+ matched_reg_lr = None
+ for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
+ if re.fullmatch(regex_str, lora.original_name):
+ matched_reg_lr = (i, reg_lr)
+ logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
+ break
+
+ for name, param in lora.named_parameters():
+ if matched_reg_lr is not None:
+ reg_idx, reg_lr = matched_reg_lr
+ group_key = f"reg_lr_{reg_idx}"
+ if group_key not in reg_groups:
+ reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
+ # LoRA+ detection: check for "up" weight parameters
+ if loraplus_ratio is not None and self._is_plus_param(name):
+ reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
+ continue
+
+ if loraplus_ratio is not None and self._is_plus_param(name):
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+ params = []
+ descriptions = []
+ for group_key, group in reg_groups.items():
+ reg_lr = group["lr"]
+ for key in ("lora", "plus"):
+ param_data = {"params": group[key].values()}
+ if len(param_data["params"]) == 0:
+ continue
+ if key == "plus":
+ param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
+ else:
+ param_data["lr"] = reg_lr
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+ params.append(param_data)
+ desc = f"reg_lr_{group_key.split('_')[-1]}"
+ descriptions.append(desc + (" plus" if key == "plus" else ""))
+
+ for key in param_groups.keys():
+ param_data = {"params": param_groups[key].values()}
+ if len(param_data["params"]) == 0:
+ continue
+ if lr is not None:
+ if key == "plus":
+ param_data["lr"] = lr * loraplus_ratio
+ else:
+ param_data["lr"] = lr
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+ params.append(param_data)
+ descriptions.append("plus" if key == "plus" else "")
+ return params, descriptions
+
+ if self.text_encoder_loras:
+ loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
+ # Group TE loras by prefix
+ for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
+ te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
+ if len(te_loras) > 0:
+ te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
+ logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
+ params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
+ all_params.extend(params)
+ lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
+
+ if self.unet_loras:
+ params, descriptions = assemble_params(
+ self.unet_loras,
+ unet_lr if unet_lr is not None else default_lr,
+ self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
+ )
+ all_params.extend(params)
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
+
+ return all_params, lr_descriptions
+
+ def _is_plus_param(self, name: str) -> bool:
+ """Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
+
+ For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
+ Override in subclass if needed. Default: check for common 'up' patterns.
+ """
+ return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
+
+ def enable_gradient_checkpointing(self):
+ pass # not supported
+
+ def prepare_grad_etc(self, text_encoder, unet):
+ self.requires_grad_(True)
+
+ def on_epoch_start(self, text_encoder, unet):
+ self.train()
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+ from library import train_util
+
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+ def backup_weights(self):
+ loras = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not hasattr(org_module, "_lora_org_weight"):
+ sd = org_module.state_dict()
+ org_module._lora_org_weight = sd["weight"].detach().clone()
+ org_module._lora_restored = True
+
+ def restore_weights(self):
+ loras = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ if not org_module._lora_restored:
+ sd = org_module.state_dict()
+ sd["weight"] = org_module._lora_org_weight
+ org_module.load_state_dict(sd)
+ org_module._lora_restored = True
+
+ def pre_calculation(self):
+ loras = self.text_encoder_loras + self.unet_loras
+ for lora in loras:
+ org_module = lora.org_module_ref[0]
+ sd = org_module.state_dict()
+
+ org_weight = sd["weight"]
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+ sd["weight"] = org_weight + lora_weight
+ assert sd["weight"].shape == org_weight.shape
+ org_module.load_state_dict(sd)
+
+ org_module._lora_restored = False
+ lora.enabled = False