Add block dim(rank) feature

This commit is contained in:
Kohya S
2023-04-03 21:19:49 +09:00
parent 817a9268ff
commit 6134619998
4 changed files with 361 additions and 256 deletions

View File

@@ -127,8 +127,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
- 3 Apr. 2023, 2023/4/3:
- Add `--network_args` option to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for your great contribution!
- 4 Apr. 2023, 2023/4/4:
- Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution!
- Specify the weights of 25 blocks for the full model.
- No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values for the argument for consistency.
- Specify the following arguments with `--network_args`.
@@ -138,10 +138,19 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`.
- `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.
- If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created.
- `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0.
- 階層別学習率を `train_network.py` で指定できるようにしました。u-haru 氏の多大な貢献に感謝します。
- Add options to `train_network.py` to specify block dims (ranks) for variable rank.
- Specify 25 values for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values always.
- Specify the following arguments with `--network_args`.
- `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`.
- `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.
- `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block.
- `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used.
- 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。
- フルモデルの25個のブロックの重みを指定できます。
- 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
- 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
-`--network_args` で以下の引数を指定してください。
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
@@ -149,33 +158,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
- 1 Apr. 2023, 2023/4/1:
- Fix an issue that `merge_lora.py` does not work with the latest version.
- Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights.
- 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました
- `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました
- 31 Mar. 2023, 2023/3/31:
- Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`.
- Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354)
- `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。
- `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354)
- 30 Mar. 2023, 2023/3/30:
- Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev!
- See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details.
- Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported.
- Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option.
- Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec!
- Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option.
- 階層別dim (rank)を `train_network.py` で指定できるようになりました。
- フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
- `--network_args` で以下の引数を指定してください。
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
- 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification:
` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"`
` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"`
- 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification:
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"`
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
- [P+](https://prompt-plus.github.io/) の学習に対応しました。jakaline-dev氏に感謝します。
- 詳細は [#327](https://github.com/kohya-ss/sd-scripts/pull/327) をご参照ください。
- 学習には `train_textual_inversion_XTI.py` を使用します。使用法は `train_textual_inversion.py` とほぼ同じです。た
だし学習中のサンプル生成には対応していません。
- 画像生成には `gen_img_diffusers.py` を使用してくださいWeb UIは対応していないと思われます。`--XTI_embeddings` オプションで学習したembeddingを指定してください。
- `train_network.py` で起動時のRAM使用量を削減しました。[#332](https://github.com/kohya-ss/sd-scripts/pull/332) guaneec氏に感謝します。
- `gen_img_diffusers.py` でLoRAの事前マージに対応しました。`--network_merge` オプションを指定してください。なおプロンプトオプションの `--am` は使用できなくなります。
## Sample image generation during training
A prompt file might look like this, for example

View File

@@ -2275,7 +2275,7 @@ def main(args):
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network = imported_module.create_network_from_weights(
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
)
else:
@@ -2285,6 +2285,8 @@ def main(args):
if not args.network_merge:
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False)
print(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
@@ -2292,7 +2294,7 @@ def main(args):
networks.append(network)
else:
network.merge_to(text_encoder, unet, dtype, device)
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
else:
networks = []

View File

@@ -145,8 +145,8 @@ def svd(args):
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
# load state dict to LoRA and save it
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")

View File

@@ -143,6 +143,8 @@ class LoRAModule(torch.nn.Module):
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
@@ -154,34 +156,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
else:
conv_alpha = float(conv_alpha)
"""
block_dims = kwargs.get("block_dims")
block_alphas = None
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight = kwargs.get("down_lr_weight", None)
mid_lr_weight = kwargs.get("mid_lr_weight", None)
up_lr_weight = kwargs.get("up_lr_weight", None)
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
block_alphas = kwargs.get("block_alphas", None)
conv_block_dims = kwargs.get("conv_block_dims", None)
conv_block_alphas = kwargs.get("conv_block_alphas", None)
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
)
if block_dims is not None:
block_dims = [int(d) for d in block_dims.split(',')]
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
block_alphas = kwargs.get("block_alphas")
if block_alphas is None:
block_alphas = [1] * len(block_dims)
else:
block_alphas = [int(a) for a in block_alphas(',')]
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
conv_block_dims = kwargs.get("conv_block_dims")
conv_block_alphas = None
if conv_block_dims is not None:
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
conv_block_alphas = kwargs.get("conv_block_alphas")
if conv_block_alphas is None:
conv_block_alphas = [1] * len(conv_block_dims)
else:
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
"""
block_alphas = None
conv_block_dims = None
conv_block_alphas = None
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoder,
unet,
@@ -190,28 +208,219 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
alpha=network_alpha,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
block_alphas=block_alphas,
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
)
# if some parameters are not set, use zero
up_lr_weight = kwargs.get("up_lr_weight", None)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight = kwargs.get("down_lr_weight", None)
if down_lr_weight is not None:
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
mid_lr_weight = kwargs.get("mid_lr_weight", None)
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
return network
# このメソッドは外部から呼び出される可能性を考慮しておく
# network_dim, network_alpha にはデフォルト値が入っている。
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
def get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
def parse_ints(s):
return [int(i) for i in s.split(",")]
def parse_floats(s):
return [float(i) for i in s.split(",")]
# block_dimsとblock_alphasをパースする。必ず値が入る
if block_dims is not None:
block_dims = parse_ints(block_dims)
assert (
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
block_alphas = parse_floats(block_alphas)
assert (
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
# conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
if conv_block_dims is not None:
conv_block_dims = parse_ints(conv_block_dims)
assert (
len(conv_block_dims) == num_total_blocks
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
if conv_block_alphas is not None:
conv_block_alphas = parse_floats(conv_block_alphas)
assert (
len(conv_block_alphas) == num_total_blocks
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
conv_block_dims = None
conv_block_alphas = None
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
def get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
) -> Tuple[List[float], List[float], List[float]]:
# パラメータ未指定時は何もせず、今までと同じ動作とする
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return None, None, None
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
def get_list(name_with_suffix) -> List[float]:
import math
tokens = name_with_suffix.split("+")
name = tokens[0]
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
elif name == "linear":
return [i / (max_len - 1) + base_lr for i in range(max_len)]
elif name == "reverse_linear":
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
return None
if type(down_lr_weight) == str:
down_lr_weight = get_list(down_lr_weight)
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
else:
print("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
else:
print("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
else:
print("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
def remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
):
# set 0 to block dim without learning rate to remove the block
if down_lr_weight != None:
for i, lr in enumerate(down_lr_weight):
if lr == 0:
block_dims[i] = 0
if conv_block_dims is not None:
conv_block_dims[i] = 0
if mid_lr_weight != None:
if mid_lr_weight == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if up_lr_weight != None:
for i, lr in enumerate(up_lr_weight):
if lr == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 外部から呼び出す可能性を考慮しておく
def get_block_index(lora_name: str) -> int:
block_idx = -1 # invalid lora name
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
return block_idx
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
@@ -242,8 +451,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
modules_alpha = modules_dim[key]
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
network.weights_sd = weights_sd
return network
return network, weights_sd
class LoRANetwork(torch.nn.Module):
@@ -265,9 +473,22 @@ class LoRANetwork(torch.nn.Module):
alpha=1,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
block_alphas=None,
conv_block_dims=None,
conv_block_alphas=None,
modules_dim=None,
modules_alpha=None,
varbose=False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
1. lora_dimとalphaを指定
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
5. modules_dimとmodules_alphaを指定 (推論用)
"""
super().__init__()
self.multiplier = multiplier
@@ -278,62 +499,83 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
if self.apply_to_conv2d_3x3:
if self.conv_alpha is None:
self.conv_alpha = self.alpha
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
# TODO get block index here
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
dim = None
alpha = None
if modules_dim is not None:
if lora_name not in modules_dim:
continue # no LoRA module in this weights file
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
block_idx = get_block_index(lora_name)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
elif conv_block_dims is not None:
dim = conv_block_dims[block_idx]
alpha = conv_block_alphas[block_idx]
else:
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
elif self.apply_to_conv2d_3x3:
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
else:
continue
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
continue
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
loras.append(lora)
return loras
return loras, skipped
self.text_encoder_loras = create_modules(
LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.conv_lora_dim is not None:
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -351,39 +593,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file)
else:
self.weights_sd = torch.load(file, map_location="cpu")
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
if self.weights_sd:
weights_has_text_encoder = weights_has_unet = False
for key in self.weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
weights_has_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
weights_has_unet = True
if apply_text_encoder is None:
apply_text_encoder = weights_has_text_encoder
else:
assert (
apply_text_encoder == weights_has_text_encoder
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
if apply_unet is None:
apply_unet = weights_has_unet
else:
assert (
apply_unet == weights_has_unet
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
else:
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
else:
@@ -394,32 +604,14 @@ class LoRANetwork(torch.nn.Module):
else:
self.unet_loras = []
skipped = []
for lora in self.text_encoder_loras + self.unet_loras:
if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight
skipped.append(lora.lora_name)
continue
lora.apply_to()
self.add_module(lora.lora_name, lora)
if len(skipped) > 0:
print(
f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
if self.weights_sd:
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
info = self.load_state_dict(self.weights_sd, False)
print(f"weights are loaded: {info}")
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, dtype, device):
assert self.weights_sd is not None, "weights are not loaded"
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
apply_text_encoder = apply_unet = False
for key in self.weights_sd.keys():
for key in weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
@@ -437,9 +629,9 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in self.weights_sd.keys():
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
@@ -447,113 +639,18 @@ class LoRANetwork(torch.nn.Module):
# 層別学習率用に層ごとの学習率に対する倍率を定義する
def set_block_lr_weight(
self,
up_lr_weight: Union[List[float], str] = None,
up_lr_weight: List[float] = None,
mid_lr_weight: float = None,
down_lr_weight: Union[List[float], str] = None,
zero_threshold: float = 0.0,
down_lr_weight: List[float] = None,
):
# パラメータ未指定時は何もせず、今までと同じ動作とする
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
def get_list(name_with_suffix) -> List[float]:
import math
tokens = name_with_suffix.split("+")
name = tokens[0]
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
elif name == "linear":
return [i / (max_len - 1) + base_lr for i in range(max_len)]
elif name == "reverse_linear":
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
return None
if type(down_lr_weight) == str:
down_lr_weight = get_list(down_lr_weight)
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
self.block_lr = True
if down_lr_weight != None:
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
else:
print("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", self.mid_lr_weight)
else:
print("mid_lr_weight: 1.0")
if up_lr_weight != None:
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
else:
print("up_lr_weight: all 1.0, すべて1.0")
return
def get_block_index(self, lora: LoRAModule) -> int:
block_idx = -1 # invalid lora name
m = RE_UPDOWN.search(lora.lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora.lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
return block_idx
self.block_lr = True
self.down_lr_weight = down_lr_weight
self.mid_lr_weight = mid_lr_weight
self.up_lr_weight = up_lr_weight
def get_lr_weight(self, lora: LoRAModule) -> float:
lr_weight = 1.0
block_idx = self.get_block_index(lora)
block_idx = get_block_index(lora.lora_name)
if block_idx < 0:
return lr_weight
@@ -590,7 +687,7 @@ class LoRANetwork(torch.nn.Module):
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
block_idx_to_lora = {}
for lora in self.unet_loras:
idx = self.get_block_index(lora)
idx = get_block_index(lora.lora_name)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)