mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add block dim(rank) feature
README.md | 62
@@ -127,8 +127,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

 ## Change History

-- 3 Apr. 2023, 2023/4/3:
-  - Add `--network_args` option to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for your great contribution!
+- 4 Apr. 2023, 2023/4/4:
+  - Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution!
     - Specify the weights of 25 blocks for the full model.
       - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values for the argument for consistency.
     - Specify the following arguments with `--network_args`.
@@ -138,10 +138,19 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

       - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"mid_lr_weight=0.5"`.
       - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.
       - If you omit some of the arguments, 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created.
+      - `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0.
+  - Add options to `train_network.py` to specify block dims (ranks) for variable rank.
+    - Specify 25 values for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values always.
+    - Specify the following arguments with `--network_args`.
+      - `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`.
+      - `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.
+      - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block.
+      - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used.

-  - 階層別学習率を `train_network.py` で指定できるようにしました。u-haru 氏の多大な貢献に感謝します。
+  - 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。
     - フルモデルの25個のブロックの重みを指定できます。
-      - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合は一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
+      - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
     - `--network_args` で以下の引数を指定してください。
       - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
         - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
@@ -149,33 +158,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

       - `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"mid_lr_weight=0.5"` のように数値を一つだけ指定します。
       - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
       - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
+      - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。

+  - 階層別dim (rank)を `train_network.py` で指定できるようになりました。
+    - フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
+    - `--network_args` で以下の引数を指定してください。
+      - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
+      - `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
+      - `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
+      - `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
+
+  - 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification:
+    ` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"`
+    ` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"`
+
+  - 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification:
+    ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"`
+    ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
+    ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
+
-- 1 Apr. 2023, 2023/4/1:
-  - Fix an issue that `merge_lora.py` does not work with the latest version.
-  - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
-  - 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。
-  - `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。
-- 31 Mar. 2023, 2023/3/31:
-  - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
-  - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
-  - `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。
-  - `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354)
-- 30 Mar. 2023, 2023/3/30:
-  - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
-    - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
-    - Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
-    - Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
-  - Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
-  - Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
-  - [P+](https://prompt-plus.github.io/) の学習に対応しました。jakaline-dev氏に感謝します。
-    - 詳細は [#327](https://github.com/kohya-ss/sd-scripts/pull/327) をご参照ください。
-    - 学習には `train_textual_inversion_XTI.py` を使用します。使用法は `train_textual_inversion.py` とほぼ同じです。ただし学習中のサンプル生成には対応していません。
-    - 画像生成には `gen_img_diffusers.py` を使用してください(Web UIは対応していないと思われます)。`--XTI_embeddings` オプションで学習したembeddingを指定してください。
-  - `train_network.py` で起動時のRAM使用量を削減しました。[#332](https://github.com/kohya-ss/sd-scripts/pull/332) guaneec氏に感謝します。
-  - `gen_img_diffusers.py` でLoRAの事前マージに対応しました。`--network_merge` オプションを指定してください。なおプロンプトオプションの `--am` は使用できなくなります。

 ## Sample image generation during training

 A prompt file might look like this, for example
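To make the 25-value layout used by these options concrete, here is a small sketch (illustrative only; the ordering follows the `down`/`mid`/`up` naming in the options above and `NUM_OF_BLOCKS = 12` in `networks/lora.py` below):

```python
# Sketch: how a 25-element specification such as block_dims maps onto U-Net blocks.
# 12 down blocks, 1 mid block, 12 up blocks (NUM_OF_BLOCKS = 12).
values = "2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2".split(",")
assert len(values) == 25

down = values[0:12]   # shallower -> deeper; the first block has no LoRA but still takes a value
mid = values[12:13]   # single mid block
up = values[13:25]    # deeper -> shallower
```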
@@ -2275,7 +2275,7 @@ def main(args):

             if metadata is not None:
                 print(f"metadata for: {network_weight}: {metadata}")

-            network = imported_module.create_network_from_weights(
+            network, weights_sd = imported_module.create_network_from_weights(
                 network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
             )
         else:

@@ -2285,6 +2285,8 @@ def main(args):

         if not args.network_merge:
             network.apply_to(text_encoder, unet)
+            info = network.load_state_dict(weights_sd, False)
+            print(f"weights are loaded: {info}")

             if args.opt_channels_last:
                 network.to(memory_format=torch.channels_last)

@@ -2292,7 +2294,7 @@ def main(args):

             networks.append(network)
         else:
-            network.merge_to(text_encoder, unet, dtype, device)
+            network.merge_to(text_encoder, unet, weights_sd, dtype, device)

     else:
         networks = []
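Condensed, the three hunks above change the caller (apparently `gen_img_diffusers.py`, which owns `--network_merge`) to the following flow. This is a sketch that only rearranges the calls shown in the hunks; the surrounding variables are assumed to be set up as before:

```python
# Sketch of the new calling convention: create_network_from_weights no longer loads
# the weights itself; it returns them so the caller can either load or merge them.
network, weights_sd = imported_module.create_network_from_weights(
    network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
)

if not args.network_merge:
    network.apply_to(text_encoder, unet)                 # hook LoRA modules into the models
    info = network.load_state_dict(weights_sd, False)    # then load the trained weights
    print(f"weights are loaded: {info}")
    networks.append(network)
else:
    # pre-merge the LoRA weights directly into the base models
    network.merge_to(text_encoder, unet, weights_sd, dtype, device)
```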
@@ -145,7 +145,7 @@ def svd(args):

   lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])

   # load state dict to LoRA and save it
-  lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
+  lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
   lora_network_save.apply_to(text_encoder_o, unet_o)  # create internal module references for state_dict

   info = lora_network_save.load_state_dict(lora_sd)
networks/lora.py | 655
@@ -143,6 +143,8 @@ class LoRAModule(torch.nn.Module):

 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0

     # extract dim/alpha for conv2d, and block dim
     conv_dim = kwargs.get("conv_dim", None)
@@ -154,34 +156,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un

     else:
         conv_alpha = float(conv_alpha)

-    """
-    block_dims = kwargs.get("block_dims")
-    block_alphas = None
-
-    if block_dims is not None:
-      block_dims = [int(d) for d in block_dims.split(',')]
-      assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
-      block_alphas = kwargs.get("block_alphas")
-      if block_alphas is None:
-        block_alphas = [1] * len(block_dims)
-      else:
-        block_alphas = [int(a) for a in block_alphas(',')]
-      assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
-
-    conv_block_dims = kwargs.get("conv_block_dims")
-    conv_block_alphas = None
-
-    if conv_block_dims is not None:
-      conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
-      assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
-      conv_block_alphas = kwargs.get("conv_block_alphas")
-      if conv_block_alphas is None:
-        conv_block_alphas = [1] * len(conv_block_dims)
-      else:
-        conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
-      assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
-    """
+    # block dim/alpha/lr
+    block_dims = kwargs.get("block_dims", None)
+    down_lr_weight = kwargs.get("down_lr_weight", None)
+    mid_lr_weight = kwargs.get("mid_lr_weight", None)
+    up_lr_weight = kwargs.get("up_lr_weight", None)
+
+    # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
+    if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
+        block_alphas = kwargs.get("block_alphas", None)
+        conv_block_dims = kwargs.get("conv_block_dims", None)
+        conv_block_alphas = kwargs.get("conv_block_alphas", None)
+
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
+            block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+        )
+
+        # extract learning rate weight for each block
+        if down_lr_weight is not None:
+            # if some parameters are not set, use zero
+            if "," in down_lr_weight:
+                down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+        if mid_lr_weight is not None:
+            mid_lr_weight = float(mid_lr_weight)
+
+        if up_lr_weight is not None:
+            if "," in up_lr_weight:
+                up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+        down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
+            down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
+        )
+
+        # remove block dim/alpha without learning rate
+        block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
+            block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+        )
+    else:
+        block_alphas = None
+        conv_block_dims = None
+        conv_block_alphas = None
+
+    # すごく引数が多いな ( ^ω^)・・・
     network = LoRANetwork(
         text_encoder,
         unet,
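For context, a minimal sketch of how a trainer could drive the new branch of `create_network` above. The keyword names are exactly the `--network_args` keys from the README; the model objects (`vae`, `text_encoder`, `unet`) and the values are placeholders:

```python
# Sketch only: values arrive as strings, exactly as --network_args would pass them.
import networks.lora as lora

network = lora.create_network(
    1.0,    # multiplier
    None,   # network_dim   (defaults to 4 inside create_network)
    None,   # network_alpha (defaults to 1.0, added in this commit)
    vae, text_encoder, unet,
    block_dims="2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",
    block_alphas="2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",
    down_lr_weight="0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5",
    mid_lr_weight="2.0",
    up_lr_weight="1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",
    block_lr_zero_threshold="0.1",
)
```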
@@ -190,271 +208,95 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un

         alpha=network_alpha,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
+        block_dims=block_dims,
+        block_alphas=block_alphas,
+        conv_block_dims=conv_block_dims,
+        conv_block_alphas=conv_block_alphas,
+        varbose=True,
     )

-    # if some parameters are not set, use zero
-    up_lr_weight = kwargs.get("up_lr_weight", None)
-    if up_lr_weight is not None:
-        if "," in up_lr_weight:
-            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
-
-    down_lr_weight = kwargs.get("down_lr_weight", None)
-    if down_lr_weight is not None:
-        if "," in down_lr_weight:
-            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
-
-    mid_lr_weight = kwargs.get("mid_lr_weight", None)
-    if mid_lr_weight is not None:
-        mid_lr_weight = float(mid_lr_weight)
-
-    network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

     return network


-def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
-    if weights_sd is None:
-        if os.path.splitext(file)[1] == ".safetensors":
-            from safetensors.torch import load_file, safe_open
-
-            weights_sd = load_file(file)
-        else:
-            weights_sd = torch.load(file, map_location="cpu")
-
-    # get dim/alpha mapping
-    modules_dim = {}
-    modules_alpha = {}
-    for key, value in weights_sd.items():
-        if "." not in key:
-            continue
-
-        lora_name = key.split(".")[0]
-        if "alpha" in key:
-            modules_alpha[lora_name] = value
-        elif "lora_down" in key:
-            dim = value.size()[0]
-            modules_dim[lora_name] = dim
-            # print(lora_name, value.size(), dim)
-
-    # support old LoRA without alpha
-    for key in modules_dim.keys():
-        if key not in modules_alpha:
-            modules_alpha = modules_dim[key]
-
-    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
-    network.weights_sd = weights_sd
-    return network
-
-
-class LoRANetwork(torch.nn.Module):
-    NUM_OF_BLOCKS = 12  # フルモデル相当でのup,downの層の数
-
-    # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
-    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
-    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
-    LORA_PREFIX_UNET = "lora_unet"
-    LORA_PREFIX_TEXT_ENCODER = "lora_te"
-
-    def __init__(
-        self,
-        text_encoder,
-        unet,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
-        conv_lora_dim=None,
-        conv_alpha=None,
-        modules_dim=None,
-        modules_alpha=None,
-    ) -> None:
-        super().__init__()
-        self.multiplier = multiplier
-
-        self.lora_dim = lora_dim
-        self.alpha = alpha
-        self.conv_lora_dim = conv_lora_dim
-        self.conv_alpha = conv_alpha
-
-        if modules_dim is not None:
-            print(f"create LoRA network from weights")
-        else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-
-        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
-        if self.apply_to_conv2d_3x3:
-            if self.conv_alpha is None:
-                self.conv_alpha = self.alpha
-            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
-
-        # create module instances
-        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
-            loras = []
-            for name, module in root_module.named_modules():
-                if module.__class__.__name__ in target_replace_modules:
-                    # TODO get block index here
-                    for child_name, child_module in module.named_modules():
-                        is_linear = child_module.__class__.__name__ == "Linear"
-                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
-                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
-                        if is_linear or is_conv2d:
-                            lora_name = prefix + "." + name + "." + child_name
-                            lora_name = lora_name.replace(".", "_")
-
-                            if modules_dim is not None:
-                                if lora_name not in modules_dim:
-                                    continue  # no LoRA module in this weights file
-                                dim = modules_dim[lora_name]
-                                alpha = modules_alpha[lora_name]
-                            else:
-                                if is_linear or is_conv2d_1x1:
-                                    dim = self.lora_dim
-                                    alpha = self.alpha
-                                elif self.apply_to_conv2d_3x3:
-                                    dim = self.conv_lora_dim
-                                    alpha = self.conv_alpha
-                                else:
-                                    continue
-
-                            lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
-                            loras.append(lora)
-            return loras
-
-        self.text_encoder_loras = create_modules(
-            LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
-        )
-        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
-
-        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
-        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
-        if modules_dim is not None or self.conv_lora_dim is not None:
-            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
-
-        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
-        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
-
-        self.weights_sd = None
-
-        self.up_lr_weight: List[float] = None
-        self.down_lr_weight: List[float] = None
-        self.mid_lr_weight: float = None
-        self.block_lr = False
-
-        # assertion
-        names = set()
-        for lora in self.text_encoder_loras + self.unet_loras:
-            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
-            names.add(lora.lora_name)
-
-    def set_multiplier(self, multiplier):
-        self.multiplier = multiplier
-        for lora in self.text_encoder_loras + self.unet_loras:
-            lora.multiplier = self.multiplier
-
-    def load_weights(self, file):
-        if os.path.splitext(file)[1] == ".safetensors":
-            from safetensors.torch import load_file, safe_open
-
-            self.weights_sd = load_file(file)
-        else:
-            self.weights_sd = torch.load(file, map_location="cpu")
-
-    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
-        if self.weights_sd:
-            weights_has_text_encoder = weights_has_unet = False
-            for key in self.weights_sd.keys():
-                if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
-                    weights_has_text_encoder = True
-                elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
-                    weights_has_unet = True
-
-            if apply_text_encoder is None:
-                apply_text_encoder = weights_has_text_encoder
-            else:
-                assert (
-                    apply_text_encoder == weights_has_text_encoder
-                ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
-
-            if apply_unet is None:
-                apply_unet = weights_has_unet
-            else:
-                assert (
-                    apply_unet == weights_has_unet
-                ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
-        else:
-            assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
-
-        if apply_text_encoder:
-            print("enable LoRA for text encoder")
-        else:
-            self.text_encoder_loras = []
-
-        if apply_unet:
-            print("enable LoRA for U-Net")
-        else:
-            self.unet_loras = []
-
-        skipped = []
-        for lora in self.text_encoder_loras + self.unet_loras:
-            if self.block_lr and self.get_lr_weight(lora) == 0:  # no LR weight
-                skipped.append(lora.lora_name)
-                continue
-            lora.apply_to()
-            self.add_module(lora.lora_name, lora)
-
-        if len(skipped) > 0:
-            print(
-                f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
-            )
-            for name in skipped:
-                print(f"\t{name}")
-
-        if self.weights_sd:
-            # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
-            info = self.load_state_dict(self.weights_sd, False)
-            print(f"weights are loaded: {info}")
-
-    # TODO refactor to common function with apply_to
-    def merge_to(self, text_encoder, unet, dtype, device):
-        assert self.weights_sd is not None, "weights are not loaded"
-
-        apply_text_encoder = apply_unet = False
-        for key in self.weights_sd.keys():
-            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
-                apply_text_encoder = True
-            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
-                apply_unet = True
-
-        if apply_text_encoder:
-            print("enable LoRA for text encoder")
-        else:
-            self.text_encoder_loras = []
-
-        if apply_unet:
-            print("enable LoRA for U-Net")
-        else:
-            self.unet_loras = []
-
-        for lora in self.text_encoder_loras + self.unet_loras:
-            sd_for_lora = {}
-            for key in self.weights_sd.keys():
-                if key.startswith(lora.lora_name):
-                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
-            lora.merge_to(sd_for_lora, dtype, device)
-
-        print(f"weights are merged")
-
-    # 層別学習率用に層ごとの学習率に対する倍率を定義する
-    def set_block_lr_weight(
-        self,
-        up_lr_weight: Union[List[float], str] = None,
-        mid_lr_weight: float = None,
-        down_lr_weight: Union[List[float], str] = None,
-        zero_threshold: float = 0.0,
-    ):
+# このメソッドは外部から呼び出される可能性を考慮しておく
+# network_dim, network_alpha にはデフォルト値が入っている。
+# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
+# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
+def get_block_dims_and_alphas(
+    block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
+):
+    num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
+
+    def parse_ints(s):
+        return [int(i) for i in s.split(",")]
+
+    def parse_floats(s):
+        return [float(i) for i in s.split(",")]
+
+    # block_dimsとblock_alphasをパースする。必ず値が入る
+    if block_dims is not None:
+        block_dims = parse_ints(block_dims)
+        assert (
+            len(block_dims) == num_total_blocks
+        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
+    else:
+        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        block_dims = [network_dim] * num_total_blocks
+
+    if block_alphas is not None:
+        block_alphas = parse_floats(block_alphas)
+        assert (
+            len(block_alphas) == num_total_blocks
+        ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
+    else:
+        print(
+            f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
+        )
+        block_alphas = [network_alpha] * num_total_blocks
+
+    # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
+    if conv_block_dims is not None:
+        conv_block_dims = parse_ints(conv_block_dims)
+        assert (
+            len(conv_block_dims) == num_total_blocks
+        ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
+
+        if conv_block_alphas is not None:
+            conv_block_alphas = parse_floats(conv_block_alphas)
+            assert (
+                len(conv_block_alphas) == num_total_blocks
+            ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
+        else:
+            if conv_alpha is None:
+                conv_alpha = 1.0
+            print(
+                f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
+            )
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+    else:
+        if conv_dim is not None:
+            print(
+                f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
+            )
+            conv_block_dims = [conv_dim] * num_total_blocks
+            conv_block_alphas = [conv_alpha] * num_total_blocks
+        else:
+            conv_block_dims = None
+            conv_block_alphas = None
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
+
+
+# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
+def get_block_lr_weight(
+    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
+) -> Tuple[List[float], List[float], List[float]]:
     # パラメータ未指定時は何もせず、今までと同じ動作とする
     if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
-        return
+        return None, None, None

     max_len = LoRANetwork.NUM_OF_BLOCKS  # フルモデル相当でのup,downの層の数
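A quick sketch of the defaulting behaviour of `get_block_dims_and_alphas` above (illustrative; it calls the function exactly as defined in the hunk):

```python
# Sketch: with block_dims/block_alphas omitted, every one of the 25 blocks
# falls back to network_dim / network_alpha; conv blocks stay disabled.
from networks.lora import get_block_dims_and_alphas

block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
    None, None,      # block_dims, block_alphas not specified
    8, 4.0,          # network_dim, network_alpha
    None, None,      # conv_block_dims, conv_block_alphas
    None, None,      # conv_dim, conv_alpha
)
assert block_dims == [8] * 25            # LoRANetwork.NUM_OF_BLOCKS * 2 + 1
assert block_alphas == [4.0] * 25
assert conv_block_dims is None and conv_block_alphas is None
```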
@@ -504,32 +346,58 @@ class LoRANetwork(torch.nn.Module):

     if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
         print("apply block learning rate / 階層別学習率を適用します。")
-        self.block_lr = True

     if down_lr_weight != None:
-        self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
-        print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
+        down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+        print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
     else:
         print("down_lr_weight: all 1.0, すべて1.0")

     if mid_lr_weight != None:
-        self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-        print("mid_lr_weight:", self.mid_lr_weight)
+        mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+        print("mid_lr_weight:", mid_lr_weight)
     else:
         print("mid_lr_weight: 1.0")

     if up_lr_weight != None:
-        self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
-        print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
+        up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+        print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
     else:
         print("up_lr_weight: all 1.0, すべて1.0")

-    return
+    return down_lr_weight, mid_lr_weight, up_lr_weight


-def get_block_index(self, lora: LoRAModule) -> int:
+# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
+def remove_block_dims_and_alphas(
+    block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
+):
+    # set 0 to block dim without learning rate to remove the block
+    if down_lr_weight != None:
+        for i, lr in enumerate(down_lr_weight):
+            if lr == 0:
+                block_dims[i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[i] = 0
+    if mid_lr_weight != None:
+        if mid_lr_weight == 0:
+            block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+            if conv_block_dims is not None:
+                conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
+    if up_lr_weight != None:
+        for i, lr in enumerate(up_lr_weight):
+            if lr == 0:
+                block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+                if conv_block_dims is not None:
+                    conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
+
+    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
+
+
+# 外部から呼び出す可能性を考慮しておく
+def get_block_index(lora_name: str) -> int:
     block_idx = -1  # invalid lora name

-    m = RE_UPDOWN.search(lora.lora_name)
+    m = RE_UPDOWN.search(lora_name)
     if m:
         g = m.groups()
         i = int(g[1])
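To illustrate how a zero learning-rate weight removes a block (a sketch: `get_block_lr_weight` first zeroes anything at or below `block_lr_zero_threshold`, and `remove_block_dims_and_alphas` then zeroes the matching dims so no module is created):

```python
# Sketch: down block 0 has lr weight 0, so its dim is forced to 0 and the module is skipped.
from networks.lora import remove_block_dims_and_alphas

down_lr_weight = [0] + [1.0] * 11   # 12 down blocks
mid_lr_weight = 1.0
up_lr_weight = [1.0] * 12           # 12 up blocks

block_dims = [4] * 25               # 12 down + 1 mid + 12 up
block_dims, block_alphas, _, _ = remove_block_dims_and_alphas(
    block_dims, [1.0] * 25, None, None, down_lr_weight, mid_lr_weight, up_lr_weight
)
assert block_dims[0] == 0           # removed: dim 0 means "no LoRA module here"
assert block_dims[12] == 4          # mid block (index NUM_OF_BLOCKS) is kept
```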
@@ -546,14 +414,243 @@ class LoRANetwork(torch.nn.Module):

         elif g[0] == "up":
             block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx

-    elif "mid_block_" in lora.lora_name:
+    elif "mid_block_" in lora_name:
         block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12

     return block_idx


+# Create network from weights for inference, weights are not loaded here (because can be merged)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file, safe_open
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # get dim/alpha mapping
+    modules_dim = {}
+    modules_alpha = {}
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lora_down" in key:
+            dim = value.size()[0]
+            modules_dim[lora_name] = dim
+            # print(lora_name, value.size(), dim)
+
+    # support old LoRA without alpha
+    for key in modules_dim.keys():
+        if key not in modules_alpha:
+            modules_alpha = modules_dim[key]
+
+    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
+    return network, weights_sd
+
+
+class LoRANetwork(torch.nn.Module):
+    NUM_OF_BLOCKS = 12  # フルモデル相当でのup,downの層の数
+
+    # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
+    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+    def __init__(
+        self,
+        text_encoder,
+        unet,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        conv_lora_dim=None,
+        conv_alpha=None,
+        block_dims=None,
+        block_alphas=None,
+        conv_block_dims=None,
+        conv_block_alphas=None,
+        modules_dim=None,
+        modules_alpha=None,
+        varbose=False,
+    ) -> None:
+        """
+        LoRA network: すごく引数が多いが、パターンは以下の通り
+        1. lora_dimとalphaを指定
+        2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
+        3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
+        4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
+        5. modules_dimとmodules_alphaを指定 (推論用)
+        """
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.conv_lora_dim = conv_lora_dim
+        self.conv_alpha = conv_alpha
+
+        if modules_dim is not None:
+            print(f"create LoRA network from weights")
+        elif block_dims is not None:
+            print(f"create LoRA network from block_dims")
+            print(f"block_dims: {block_dims}")
+            print(f"block_alphas: {block_alphas}")
+            if conv_block_dims is not None:
+                print(f"conv_block_dims: {conv_block_dims}")
+                print(f"conv_block_alphas: {conv_block_alphas}")
+        else:
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            if self.conv_lora_dim is not None:
+                print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+
+        # create module instances
+        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
+            prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+                            if modules_dim is not None:
+                                if lora_name in modules_dim:
+                                    dim = modules_dim[lora_name]
+                                    alpha = modules_alpha[lora_name]
+                            elif is_unet and block_dims is not None:
+                                block_idx = get_block_index(lora_name)
+                                if is_linear or is_conv2d_1x1:
+                                    dim = block_dims[block_idx]
+                                    alpha = block_alphas[block_idx]
+                                elif conv_block_dims is not None:
+                                    dim = conv_block_dims[block_idx]
+                                    alpha = conv_block_alphas[block_idx]
+                            else:
+                                if is_linear or is_conv2d_1x1:
+                                    dim = self.lora_dim
+                                    alpha = self.alpha
+                                elif self.conv_lora_dim is not None:
+                                    dim = self.conv_lora_dim
+                                    alpha = self.conv_alpha
+
+                            if dim is None or dim == 0:
+                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
+                                    skipped.append(lora_name)
+                                continue
+
+                            lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
+                            loras.append(lora)
+            return loras, skipped
+
+        self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
+            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        skipped = skipped_te + skipped_un
+        if varbose and len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
+            for name in skipped:
+                print(f"\t{name}")
+
+        self.up_lr_weight: List[float] = None
+        self.down_lr_weight: List[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    # TODO refactor to common function with apply_to
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        apply_text_encoder = apply_unet = False
+        for key in weights_sd.keys():
+            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
+                apply_text_encoder = True
+            elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
+                apply_unet = True
+
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(lora.lora_name):
+                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
+            lora.merge_to(sd_for_lora, dtype, device)
+
+        print(f"weights are merged")
+
+    # 層別学習率用に層ごとの学習率に対する倍率を定義する
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: List[float] = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: List[float] = None,
+    ):
+        self.block_lr = True
+        self.down_lr_weight = down_lr_weight
+        self.mid_lr_weight = mid_lr_weight
+        self.up_lr_weight = up_lr_weight
+
     def get_lr_weight(self, lora: LoRAModule) -> float:
         lr_weight = 1.0
-        block_idx = self.get_block_index(lora)
+        block_idx = get_block_index(lora.lora_name)
         if block_idx < 0:
             return lr_weight
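The `__init__` docstring above lists five argument patterns. A minimal sketch of pattern 3 (block_dims/block_alphas without Conv2d3x3), assuming `text_encoder` and `unet` are already-loaded models:

```python
# Sketch of pattern 3: per-block dims passed as lists (create_network normally
# parses the comma-separated --network_args strings before constructing the network).
from networks.lora import LoRANetwork

network = LoRANetwork(
    text_encoder,
    unet,
    multiplier=1.0,
    block_dims=[2] * 25,     # 12 down + 1 mid + 12 up
    block_alphas=[1.0] * 25,
    varbose=True,            # report modules skipped because their dim is 0
)
network.apply_to(text_encoder, unet)  # insert the LoRA modules for training
```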
@@ -590,7 +687,7 @@ class LoRANetwork(torch.nn.Module):

         # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
         block_idx_to_lora = {}
         for lora in self.unet_loras:
-            idx = self.get_block_index(lora)
+            idx = get_block_index(lora.lora_name)
             if idx not in block_idx_to_lora:
                 block_idx_to_lora[idx] = []
             block_idx_to_lora[idx].append(lora)
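Finally, a sketch of how the per-block multipliers set by `set_block_lr_weight` can reach the optimizer. Only `set_block_lr_weight` and `get_lr_weight` come from the hunks above; the parameter-group construction and `base_lr` are assumptions about what the (not shown) optimizer setup does:

```python
# Sketch: scale each U-Net LoRA module's learning rate by its block weight.
base_lr = 1e-4  # assumed base learning rate
network.set_block_lr_weight(
    up_lr_weight=[1.0] * 12,
    mid_lr_weight=2.0,
    down_lr_weight=[0.5] * 12,
)

param_groups = []
for lora in network.unet_loras:
    weight = network.get_lr_weight(lora)  # 1.0 when no explicit weight applies
    if weight == 0:
        continue                          # block effectively disabled
    param_groups.append({"params": lora.parameters(), "lr": base_lr * weight})
```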