diff --git a/README.md b/README.md index 7ef0947a..0590bfd0 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -- 3 Apr. 2023, 2023/4/3: - - Add `--network_args` option to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for your great contribution! +- 4 Apr. 2023, 2023/4/4: + - Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution! - Specify the weights of 25 blocks for the full model. - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. - Specify the following arguments with `--network_args`. @@ -138,10 +138,19 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`. - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight. - If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created. + - `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0. - - 階層別学習率を `train_network.py` で指定できるようにしました。u-haru 氏の多大な貢献に感謝します。 + - Add options to `train_network.py` to specify block dims (ranks) for variable rank. + - Specify 25 values ​​for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values ​​always. + - Specify the following arguments with `--network_args`. + - `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`. + - `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used. + - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. + - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used. + + - 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。 - フルモデルの25個のブロックの重みを指定できます。 - - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合は一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 + - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 -`--network_args` で以下の引数を指定してください。 - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 @@ -149,33 +158,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - + - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 -- 1 Apr. 2023, 2023/4/1: - - Fix an issue that `merge_lora.py` does not work with the latest version. - - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights. - - 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。 - - `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。 -- 31 Mar. 2023, 2023/3/31: - - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`. - - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354) - - `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。 - - `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354) -- 30 Mar. 2023, 2023/3/30: - - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev! - - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details. - - Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported. - - Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option. - - Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec! - - Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option. + - 階層別dim (rank)を `train_network.py` で指定できるようになりました。 + - フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 + - `--network_args` で以下の引数を指定してください。 + - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 + - `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。 + - `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。 + - `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。 + + - 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification: + + ` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"` + + ` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"` + + - 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification: + + ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"` + + ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` + + ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` - - [P+](https://prompt-plus.github.io/) の学習に対応しました。jakaline-dev氏に感謝します。 - - 詳細は [#327](https://github.com/kohya-ss/sd-scripts/pull/327) をご参照ください。 - - 学習には `train_textual_inversion_XTI.py` を使用します。使用法は `train_textual_inversion.py` とほぼ同じです。た - だし学習中のサンプル生成には対応していません。 - - 画像生成には `gen_img_diffusers.py` を使用してください(Web UIは対応していないと思われます)。`--XTI_embeddings` オプションで学習したembeddingを指定してください。 - - `train_network.py` で起動時のRAM使用量を削減しました。[#332](https://github.com/kohya-ss/sd-scripts/pull/332) guaneec氏に感謝します。 - - `gen_img_diffusers.py` でLoRAの事前マージに対応しました。`--network_merge` オプションを指定してください。なおプロンプトオプションの `--am` は使用できなくなります。 ## Sample image generation during training A prompt file might look like this, for example diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 225de33c..a0469766 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2275,7 +2275,7 @@ def main(args): if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - network = imported_module.create_network_from_weights( + network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, **net_kwargs ) else: @@ -2285,6 +2285,8 @@ def main(args): if not args.network_merge: network.apply_to(text_encoder, unet) + info = network.load_state_dict(weights_sd, False) + print(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) @@ -2292,7 +2294,7 @@ def main(args): networks.append(network) else: - network.merge_to(text_encoder, unet, dtype, device) + network.merge_to(text_encoder, unet, weights_sd, dtype, device) else: networks = [] diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 9aa28485..f001e7eb 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -145,8 +145,8 @@ def svd(args): lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") diff --git a/networks/lora.py b/networks/lora.py index 27335efe..c5372688 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -143,6 +143,8 @@ class LoRAModule(torch.nn.Module): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 # extract dim/alpha for conv2d, and block dim conv_dim = kwargs.get("conv_dim", None) @@ -154,34 +156,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un else: conv_alpha = float(conv_alpha) - """ - block_dims = kwargs.get("block_dims") - block_alphas = None + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + down_lr_weight = kwargs.get("down_lr_weight", None) + mid_lr_weight = kwargs.get("mid_lr_weight", None) + up_lr_weight = kwargs.get("up_lr_weight", None) + + # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする + if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + # remove block dim/alpha without learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + ) - if block_dims is not None: - block_dims = [int(d) for d in block_dims.split(',')] - assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" - block_alphas = kwargs.get("block_alphas") - if block_alphas is None: - block_alphas = [1] * len(block_dims) else: - block_alphas = [int(a) for a in block_alphas(',')] - assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - - conv_block_dims = kwargs.get("conv_block_dims") - conv_block_alphas = None - - if conv_block_dims is not None: - conv_block_dims = [int(d) for d in conv_block_dims.split(',')] - assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" - conv_block_alphas = kwargs.get("conv_block_alphas") - if conv_block_alphas is None: - conv_block_alphas = [1] * len(conv_block_dims) - else: - conv_block_alphas = [int(a) for a in conv_block_alphas(',')] - assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - """ + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoder, unet, @@ -190,28 +208,219 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, ) - # if some parameters are not set, use zero - up_lr_weight = kwargs.get("up_lr_weight", None) - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight = kwargs.get("down_lr_weight", None) - if down_lr_weight is not None: - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - mid_lr_weight = kwargs.get("mid_lr_weight", None) - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) return network +# このメソッドは外部から呼び出される可能性を考慮しておく +# network_dim, network_alpha にはデフォルト値が入っている。 +# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている +# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている +def get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + # block_dimsとblock_alphasをパースする。必ず値が入る + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert ( + len(block_dims) == num_total_blocks + ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + else: + print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + print( + f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + print( + f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + print( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +def get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold +) -> Tuple[List[float], List[float], List[float]]: + # パラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None, None, None + + max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + elif name == "linear": + return [i / (max_len - 1) + base_lr for i in range(max_len)] + elif name == "reverse_linear": + return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + elif name == "zeros": + return [0.0 + base_lr] * max_len + else: + print( + "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): + print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + up_lr_weight = up_lr_weight[:max_len] + down_lr_weight = down_lr_weight[:max_len] + + if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): + print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + + if down_lr_weight != None and len(down_lr_weight) < max_len: + down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len: + up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + print("apply block learning rate / 階層別学習率を適用します。") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + else: + print("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: + mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + print("mid_lr_weight:", mid_lr_weight) + else: + print("mid_lr_weight: 1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + else: + print("up_lr_weight: all 1.0, すべて1.0") + + return down_lr_weight, mid_lr_weight, up_lr_weight + + +# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく +def remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight +): + # set 0 to block dim without learning rate to remove the block + if down_lr_weight != None: + for i, lr in enumerate(down_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + if mid_lr_weight != None: + if mid_lr_weight == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if up_lr_weight != None: + for i, lr in enumerate(up_lr_weight): + if lr == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str) -> int: + block_idx = -1 # invalid lora name + + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + + return block_idx + + +# Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": @@ -242,8 +451,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh modules_alpha = modules_dim[key] network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) - network.weights_sd = weights_sd - return network + return network, weights_sd class LoRANetwork(torch.nn.Module): @@ -265,9 +473,22 @@ class LoRANetwork(torch.nn.Module): alpha=1, conv_lora_dim=None, conv_alpha=None, + block_dims=None, + block_alphas=None, + conv_block_dims=None, + conv_block_alphas=None, modules_dim=None, modules_alpha=None, + varbose=False, ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ super().__init__() self.multiplier = multiplier @@ -278,62 +499,83 @@ class LoRANetwork(torch.nn.Module): if modules_dim is not None: print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") else: print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - - self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None - if self.apply_to_conv2d_3x3: - if self.conv_alpha is None: - self.conv_alpha = self.alpha - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + if self.conv_lora_dim is not None: + print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER loras = [] + skipped = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: - # TODO get block index here for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + if is_linear or is_conv2d: lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") + dim = None + alpha = None if modules_dim is not None: - if lora_name not in modules_dim: - continue # no LoRA module in this weights file - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] else: if is_linear or is_conv2d_1x1: dim = self.lora_dim alpha = self.alpha - elif self.apply_to_conv2d_3x3: + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha - else: - continue + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) loras.append(lora) - return loras + return loras, skipped - self.text_encoder_loras = create_modules( - LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - ) + self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.conv_lora_dim is not None: + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) + self.unet_loras, skipped_un = create_modules(True, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - self.weights_sd = None + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -351,39 +593,7 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - self.weights_sd = load_file(file) - else: - self.weights_sd = torch.load(file, map_location="cpu") - - def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): - if self.weights_sd: - weights_has_text_encoder = weights_has_unet = False - for key in self.weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - weights_has_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - weights_has_unet = True - - if apply_text_encoder is None: - apply_text_encoder = weights_has_text_encoder - else: - assert ( - apply_text_encoder == weights_has_text_encoder - ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" - - if apply_unet is None: - apply_unet = weights_has_unet - else: - assert ( - apply_unet == weights_has_unet - ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" - else: - assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" - + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: print("enable LoRA for text encoder") else: @@ -394,32 +604,14 @@ class LoRANetwork(torch.nn.Module): else: self.unet_loras = [] - skipped = [] for lora in self.text_encoder_loras + self.unet_loras: - if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight - skipped.append(lora.lora_name) - continue lora.apply_to() self.add_module(lora.lora_name, lora) - if len(skipped) > 0: - print( - f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" - ) - for name in skipped: - print(f"\t{name}") - - if self.weights_sd: - # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) - info = self.load_state_dict(self.weights_sd, False) - print(f"weights are loaded: {info}") - # TODO refactor to common function with apply_to - def merge_to(self, text_encoder, unet, dtype, device): - assert self.weights_sd is not None, "weights are not loaded" - + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_text_encoder = apply_unet = False - for key in self.weights_sd.keys(): + for key in weights_sd.keys(): if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): apply_text_encoder = True elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): @@ -437,9 +629,9 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: sd_for_lora = {} - for key in self.weights_sd.keys(): + for key in weights_sd.keys(): if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) print(f"weights are merged") @@ -447,113 +639,18 @@ class LoRANetwork(torch.nn.Module): # 層別学習率用に層ごとの学習率に対する倍率を定義する def set_block_lr_weight( self, - up_lr_weight: Union[List[float], str] = None, + up_lr_weight: List[float] = None, mid_lr_weight: float = None, - down_lr_weight: Union[List[float], str] = None, - zero_threshold: float = 0.0, + down_lr_weight: List[float] = None, ): - # パラメータ未指定時は何もせず、今までと同じ動作とする - if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return - - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 - - def get_list(name_with_suffix) -> List[float]: - import math - - tokens = name_with_suffix.split("+") - name = tokens[0] - base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 - - if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] - elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] - elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] - elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] - elif name == "zeros": - return [0.0 + base_lr] * max_len - else: - print( - "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" - % (name) - ) - return None - - if type(down_lr_weight) == str: - down_lr_weight = get_list(down_lr_weight) - if type(up_lr_weight) == str: - up_lr_weight = get_list(up_lr_weight) - - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] - - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) - - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) - - if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") - self.block_lr = True - - if down_lr_weight != None: - self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight) - else: - print("down_lr_weight: all 1.0, すべて1.0") - - if mid_lr_weight != None: - self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", self.mid_lr_weight) - else: - print("mid_lr_weight: 1.0") - - if up_lr_weight != None: - self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight) - else: - print("up_lr_weight: all 1.0, すべて1.0") - - return - - def get_block_index(self, lora: LoRAModule) -> int: - block_idx = -1 # invalid lora name - - m = RE_UPDOWN.search(lora.lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora.lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 - - return block_idx + self.block_lr = True + self.down_lr_weight = down_lr_weight + self.mid_lr_weight = mid_lr_weight + self.up_lr_weight = up_lr_weight def get_lr_weight(self, lora: LoRAModule) -> float: lr_weight = 1.0 - block_idx = self.get_block_index(lora) + block_idx = get_block_index(lora.lora_name) if block_idx < 0: return lr_weight @@ -590,7 +687,7 @@ class LoRANetwork(torch.nn.Module): # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 block_idx_to_lora = {} for lora in self.unet_loras: - idx = self.get_block_index(lora) + idx = get_block_index(lora.lora_name) if idx not in block_idx_to_lora: block_idx_to_lora[idx] = [] block_idx_to_lora[idx].append(lora)