Merge pull request #1331 from kohya-ss/lora-plus

Lora plus
Kohya S
2024-05-12 17:04:58 +09:00
committed by GitHub
5 changed files with 385 additions and 191 deletions


@@ -154,7 +154,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
- Fixed some bugs when using DeepSpeed. Related [#1247]
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the ratio as a multiple of the base learning rate. The original paper recommends 16, but adjust as needed. Please see the PR for details; a short sketch of the parameter grouping follows these notes.
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio` and `loraplus_text_encoder_lr_ratio` can be specified to set separate ratios for the U-Net and Text Encoder.
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
- `network_module` `networks.lora` and `networks.dylora` are available.
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- Specify the learning rate and dim (rank) for each block.
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
- Fused optimizer is now available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- Integrating the optimizer's step into the backward pass greatly reduces memory usage during training. The training results are identical to those without it, but if memory is plentiful, training is slower.
@@ -171,7 +182,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- `--fused_optimizer_groups` cannot be used together with `--fused_backward_pass`. When using AdaFactor, memory usage is slightly higher than with the Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: while the Fused optimizer performs backward/step for individual parameters inside the optimizer, optimizer groups reduce memory usage by grouping parameters into multiple optimizers and performing backward/step for each group. The Fused optimizer requires an implementation on the optimizer side, while optimizer groups are implemented only on the training script side. Again, this is only effective for SDXL training.
- Fixed some bugs when using DeepSpeed. Related [#1247]
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
- A method that aims to speed up training by raising the learning rate of the UP side (LoRA-B) of LoRA. Specify the ratio as a multiple. The original paper recommends 16, but it depends on the dataset and so on, so adjust as needed. See the PR for details.
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio` and `loraplus_text_encoder_lr_ratio` can be used to set individual values for the U-Net and Text Encoder.
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
- Available with `networks.lora` and `networks.dylora` as `network_module`.
- SDXL LoRA now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- The learning rate and dim (rank) can be specified for each block.
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details.
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
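As a rough sketch of the LoRA+ grouping described above (implemented in the diffs below), the following hypothetical, minimal example splits a toy LoRA pair into a base group and a "plus" group whose learning rate is multiplied by the ratio. Naming follows the `lora_up`/`lora_down` convention of `networks.lora`; `networks.dylora` checks for `lora_B` instead:

```python
import torch
from torch import nn

# Toy stand-in for a single LoRA module pair (not the real networks.lora class).
lora = nn.ModuleDict(
    {"lora_down": nn.Linear(320, 4, bias=False), "lora_up": nn.Linear(4, 320, bias=False)}
)

def assemble_params(module, lr, loraplus_ratio):
    # Split parameters: the UP side (LoRA-B) goes into the "plus" group.
    base, plus = [], []
    for name, param in module.named_parameters():
        (plus if loraplus_ratio is not None and "lora_up" in name else base).append(param)
    groups = [{"params": base, "lr": lr}]
    if plus:
        groups.append({"params": plus, "lr": lr * loraplus_ratio})  # UP side runs faster
    return groups

optimizer = torch.optim.AdamW(assemble_params(lora, 1e-4, 16))
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]
```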
### Apr 7, 2024 / 2024-04-07: v0.8.7


@@ -181,16 +181,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
See [PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) for details.
SDXL is not currently supported.
You can specify weights for the 25 blocks of the full model. No LoRA corresponds to the first block, but 25 values are used for compatibility with block-wise LoRA application and the like. Also, some blocks have no LoRA when not extending to conv2d3x3, but always specify 25 values to keep the description uniform.
For SDXL, specify 9 values each for down/up and 3 values for middle.
Specify the following arguments with `--network_args`.
- `down_lr_weight` : Specifies the learning rate weights for the down blocks of the U-Net. The following can be specified:
- Per-block weights : Specify 12 numbers such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`.
- Per-block weights : Specify 12 numbers (9 for SDXL) such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`.
- From a preset : Specify like `"down_lr_weight=sine"` (the weights are set along a sine curve). sine, cosine, linear, reverse_linear, and zeros can be specified. Appending `+number` as in `"down_lr_weight=cosine+.25"` adds the given value (the weights become 0.25 to 1.25).
- `mid_lr_weight` : Specifies the learning rate weight for the mid block of the U-Net. Specify a single number such as `"mid_lr_weight=0.5"`.
- `mid_lr_weight` : Specifies the learning rate weight for the mid block of the U-Net. Specify a single number such as `"mid_lr_weight=0.5"` (3 values for SDXL).
- `up_lr_weight` : Specifies the learning rate weights for the up blocks of the U-Net. Same as down_lr_weight.
- Parts whose specification is omitted are treated as 1.0. If a weight is set to 0, no LoRA module is created for that block.
- `block_lr_zero_threshold` : If a weight is at or below this value, the LoRA module is not created. The default is 0. An SDXL example follows this list.
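For instance, a hypothetical SDXL specification with 9 down, 3 mid, and 9 up values (the weights themselves are illustrative) would be:

```
--network_args "down_lr_weight=0,0,0,1,1,1,1,1,1" "mid_lr_weight=1,1,1" "up_lr_weight=1,1,1,1,1,1,0.5,0.5,0.5"
```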
@@ -215,6 +215,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l
You can specify the dim (rank) for each of the 25 blocks of the full model. As with block-wise learning rates, some blocks may have no LoRA, but always specify 25 values.
For SDXL, specify 23 values. Some blocks have no LoRA, but this is for compatibility with the [block-wise learning rates](./train_SDXL-en.md) of `sdxl_train.py`.
The mapping is `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
Specify the following arguments with `--network_args`.
- `block_dims` : Specifies the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`. A 23-value SDXL example follows.
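For SDXL, a hypothetical 23-value specification following the mapping above (the dims are illustrative; the first and last slots correspond to blocks that have no LoRA) would be:

```
--network_args "block_dims=2,4,4,4,4,4,4,4,4,4,8,8,8,4,4,4,4,4,4,4,4,4,2"
```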


@@ -18,10 +18,13 @@ from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
@@ -211,6 +214,16 @@ def create_network(
unit=unit,
varbose=True,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
@@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
self.alpha = alpha
self.apply_to_conv = apply_to_conv
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
@@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module):
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module):
else:
index = None
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
@@ -346,6 +363,11 @@ class DyLoRANetwork(torch.nn.Module):
self.unet_loras = create_modules(True, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -406,27 +428,53 @@ class DyLoRANetwork(torch.nn.Module):
logger.info(f"weights are merged")
"""
# It might be good to allow separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue
params.append(param_data)
return params
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
)
all_params.extend(params)
return all_params
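A minimal usage sketch of the updated signature, assuming `network` is the `DyLoRANetwork` returned by the `create_network` call above with `loraplus_lr_ratio=16` (the learning-rate values are hypothetical):

```python
import torch

# network: DyLoRANetwork from create_network(..., loraplus_lr_ratio=16)
params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
# Each group is {"params": ..., "lr": ...}; groups holding lora_B weights carry lr * 16.
optimizer = torch.optim.AdamW(params)
```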


@@ -12,6 +12,7 @@ import numpy as np
import torch
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
setup_logging()
import logging
@@ -385,14 +386,14 @@ class LoRAInfModule(LoRAModule):
return out
def parse_block_lr_kwargs(nw_kwargs):
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
# if none of these is specified, treat block lr as disabled and return None
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None, None, None
return None
# extract learning rate weight for each block
if down_lr_weight is not None:
@@ -401,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
return get_block_lr_weight(
is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
)
return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(
multiplier: float,
@@ -424,6 +423,9 @@ def create_network(
neuron_dropout: Optional[float] = None,
**kwargs,
):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -441,21 +443,21 @@ def create_network(
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
# enable per-block dim (rank) if any of the above is specified
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
if block_dims is not None or block_lr_weight is not None:
block_alphas = kwargs.get("block_alphas", None)
conv_block_dims = kwargs.get("conv_block_dims", None)
conv_block_alphas = kwargs.get("conv_block_alphas", None)
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
)
else:
@@ -488,10 +490,20 @@ def create_network(
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
is_sdxl=is_sdxl,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network
@@ -501,9 +513,13 @@ def create_network(
# block_dims and block_alphas are either both None or both set
# conv_dim and conv_alpha are either both None or both set
def get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
if not is_sdxl:
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
else:
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
def parse_ints(s):
return [int(i) for i in s.split(",")]
@@ -514,9 +530,10 @@ def get_block_dims_and_alphas(
# parse block_dims and block_alphas; values are always set afterwards
if block_dims is not None:
block_dims = parse_ints(block_dims)
assert (
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
assert len(block_dims) == num_total_blocks, (
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
+ f" / block_dims{num_total_blocks}個指定してください(指定された個数: {len(block_dims)}"
)
else:
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
@@ -567,15 +584,25 @@ def get_block_dims_and_alphas(
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# defines per-block multipliers on the learning rate for block-wise lr; may be called from outside
# defines per-block multipliers on the learning rate for block-wise lr; kept outside the class so it can be called externally
# the return value is a list of per-block multipliers
def get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
) -> Tuple[List[float], List[float], List[float]]:
is_sdxl,
down_lr_weight: Union[str, List[float]],
mid_lr_weight: List[float],
up_lr_weight: Union[str, List[float]],
zero_threshold: float,
) -> Optional[List[float]]:
# if no parameter is specified, do nothing and keep the previous behavior
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return None, None, None
return None
max_len = LoRANetwork.NUM_OF_BLOCKS # number of up/down blocks in the full model
if not is_sdxl:
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
else:
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
def get_list(name_with_suffix) -> List[float]:
import math
@@ -585,15 +612,18 @@ def get_block_lr_weight(
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
return [
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
for i in reversed(range(max_len_for_down_or_up))
]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "linear":
return [i / (max_len - 1) + base_lr for i in range(max_len)]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "reverse_linear":
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
elif name == "zeros":
return [0.0 + base_lr] * max_len
return [0.0 + base_lr] * max_len_for_down_or_up
else:
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
@@ -606,20 +636,36 @@ def get_block_lr_weight(
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightがすぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
logger.warning("mid_weightがすぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
logger.warning(
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
)
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
logger.info("apply block learning rate / 階層別学習率を適用します。")
@@ -627,72 +673,84 @@ def get_block_lr_weight(
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
down_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
logger.info("mid_lr_weight: 1.0")
mid_lr_weight = [1.0] * max_len_for_mid
logger.info("mid_lr_weight: all 1.0, すべて1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
up_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
if is_sdxl:
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
), f"lr_weight length is invalid: {len(lr_weight)}"
return lr_weight
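A quick check of the presets above, assuming this module is importable as `networks.lora` (the call mirrors how `parse_block_lr_kwargs` invokes the function):

```python
from networks.lora import get_block_lr_weight

# SDXL, down blocks zeroed via the "zeros" preset, mid/up left unspecified:
w = get_block_lr_weight(True, "zeros", None, None, 0.0)
print(len(w))  # 23 = 1 (emb) + 9 (down) + 3 (mid) + 9 (up) + 1 (out)
print(w[:4])   # [1.0, 0.0, 0.0, 0.0]; the leading 1.0 is the time/label embed slot
```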
# exclude blocks whose lr_weight is 0 from block_dims; may be called from outside
def remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
):
# set 0 to block dim without learning rate to remove the block
if down_lr_weight != None:
for i, lr in enumerate(down_lr_weight):
if block_lr_weight is not None:
for i, lr in enumerate(block_lr_weight):
if lr == 0:
block_dims[i] = 0
if conv_block_dims is not None:
conv_block_dims[i] = 0
if mid_lr_weight != None:
if mid_lr_weight == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if up_lr_weight != None:
for i, lr in enumerate(up_lr_weight):
if lr == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# may be called from outside
def get_block_index(lora_name: str) -> int:
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
block_idx = -1 # invalid lora name
if not is_sdxl:
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # no LoRA corresponds to block 0
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
if g[0] == "down":
block_idx = 1 + idx # no LoRA corresponds to block 0
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
else:
# copy from sdxl_train
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
block_idx = 0 # 0
elif name.startswith("input_blocks_"): # 1-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10-12
block_idx = 10 + int(name.split("_")[2])
elif name.startswith("output_blocks_"): # 13-21
block_idx = 13 + int(name.split("_")[2])
elif name.startswith("out_"): # 22, out, no LoRA
block_idx = 22
return block_idx
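A few illustrative calls against the mapping above (module names are abbreviated examples; assumes this module is importable as `networks.lora`):

```python
from networks.lora import get_block_index

print(get_block_index("lora_unet_down_blocks_1_attentions_0_proj_in"))           # 4: SD, 1 + (3*1 + 0)
print(get_block_index("lora_unet_mid_block_attentions_0_proj_in"))               # 12: SD mid block
print(get_block_index("lora_unet_input_blocks_4_1_proj_in", is_sdxl=True))       # 5: SDXL, 1 + 4
print(get_block_index("lora_unet_output_blocks_2_0_in_layers_2", is_sdxl=True))  # 15: SDXL, 13 + 2
```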
@@ -734,15 +792,18 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
)
# block lr
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
block_lr_weight = parse_block_lr_kwargs(kwargs)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # number of up/down blocks in the full model
NUM_OF_MID_BLOCKS = 1
SDXL_NUM_OF_BLOCKS = 9 # number of input/output blocks in the SDXL model; total = 1 (base) + 9 (input) + 3 (mid) + 9 (output) + 1 (out) = 23
SDXL_NUM_OF_MID_BLOCKS = 3
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
@@ -774,6 +835,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
is_sdxl: Optional[bool] = False,
) -> None:
"""
LoRA network: there are a lot of arguments, but the patterns are as follows:
@@ -794,6 +856,10 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
@@ -855,7 +921,7 @@ class LoRANetwork(torch.nn.Module):
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
# U-Netでblock_dims指定あり
block_idx = get_block_index(lora_name)
block_idx = get_block_index(lora_name, is_sdxl)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
@@ -925,9 +991,7 @@ class LoRANetwork(torch.nn.Module):
for name in skipped:
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None
self.block_lr_weight = None
self.block_lr = False
# assertion
@@ -958,12 +1022,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
@@ -1004,81 +1068,114 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"weights are merged")
# defines per-block multipliers on the learning rate for block-wise lr; the argument order is reversed, but leave it for now
def set_block_lr_weight(
self,
up_lr_weight: List[float] = None,
mid_lr_weight: float = None,
down_lr_weight: List[float] = None,
):
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
self.block_lr = True
self.down_lr_weight = down_lr_weight
self.mid_lr_weight = mid_lr_weight
self.up_lr_weight = up_lr_weight
self.block_lr_weight = block_lr_weight
def get_lr_weight(self, lora: LoRAModule) -> float:
lr_weight = 1.0
block_idx = get_block_index(lora.lora_name)
if block_idx < 0:
return lr_weight
def get_lr_weight(self, block_idx: int) -> float:
if not self.block_lr or self.block_lr_weight is None:
return 1.0
return self.block_lr_weight[block_idx]
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
if self.down_lr_weight != None:
lr_weight = self.down_lr_weight[block_idx]
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
if self.mid_lr_weight != None:
lr_weight = self.mid_lr_weight
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
if self.up_lr_weight != None:
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
return lr_weight
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
# It might be good to allow separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
# if (
# self.loraplus_lr_ratio is not None
# or self.loraplus_text_encoder_lr_ratio is not None
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
def enumerate_params(loras):
params = []
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
return params
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
if self.block_lr:
is_sdxl = False
for lora in self.unet_loras:
if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
is_sdxl = True
break
# classify the LoRAs by block so that learning-rate graphs can be drawn per block
block_idx_to_lora = {}
for lora in self.unet_loras:
idx = get_block_index(lora.lora_name)
idx = get_block_index(lora.lora_name, is_sdxl)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)
# set the parameters for each block
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params, descriptions = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported


@@ -53,7 +53,15 @@ class NetworkTrainer:
# TODO unify this with the other scripts
def generate_step_logs(
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
self,
args: argparse.Namespace,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -63,34 +71,26 @@ class NetworkTrainer:
logs["max_norm/max_key_norm"] = maximum_norm
lrs = lr_scheduler.get_last_lr()
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
if args.network_train_unet_only:
logs["lr/unet"] = float(lrs[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lrs[0])
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
lr_desc = lr_descriptions[i]
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
idx = i - (0 if args.network_train_unet_only else -1)
if idx == -1:
lr_desc = "textencoder"
else:
if len(lrs) > 2:
lr_desc = f"group{idx}"
else:
lr_desc = "unet"
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
logs[f"lr/{lr_desc}"] = lr
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
# tracking d*lr value
logs[f"lr/d*lr/{lr_desc}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
else:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1
for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
return logs
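With LoRA+ enabled and block-wise lr off, the descriptions returned by `prepare_optimizer_params` produce log keys along these lines (a hypothetical example; the values assume loraplus_lr_ratio=16):

```python
# lr_descriptions = ["textencoder", "textencoder plus", "unet", "unet plus"]
logs = {
    "loss/current": 0.0123,
    "loss/average": 0.0145,
    "lr/textencoder": 5e-05,
    "lr/textencoder plus": 8e-04,  # text encoder lr * 16
    "lr/unet": 1e-04,
    "lr/unet plus": 1.6e-03,       # U-Net lr * 16
}
```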
@@ -323,6 +323,7 @@ class NetworkTrainer:
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -338,12 +339,30 @@ class NetworkTrainer:
# keep backward compatibility
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
)
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
if type(results) is tuple:
trainable_params = results[0]
lr_descriptions = results[1]
else:
trainable_params = results
lr_descriptions = None
except TypeError as e:
# logger.warning(f"{e}")
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
lr_descriptions = None
# if len(trainable_params) == 0:
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
# for params in trainable_params:
# for k, v in params.items():
# if type(v) == float:
# pass
# else:
# v = len(v)
# accelerator.print(f"trainable_params: {k} = {v}")
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -952,7 +971,9 @@ class NetworkTrainer:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
@@ -1103,6 +1124,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
return parser
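The LoRA+ ratios are read from `--network_args` rather than from the dedicated flags commented out above. A hypothetical SDXL invocation combining LoRA+ with block-wise learning rates (the model path and hyperparameters are placeholders) might look like:

```
accelerate launch sdxl_train_network.py \
  --pretrained_model_name_or_path model.safetensors \
  --network_module networks.lora \
  --network_args "loraplus_lr_ratio=16" "down_lr_weight=1,1,1,1,1,1,1,1,1" \
    "mid_lr_weight=1,1,1" "up_lr_weight=0,0,0,1,1,1,1,1,1" \
  --learning_rate 1e-4
```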