mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add block dim(rank) feature
This commit is contained in:
@@ -145,8 +145,8 @@ def svd(args):
|
||||
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
|
||||
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
543
networks/lora.py
543
networks/lora.py
@@ -143,6 +143,8 @@ class LoRAModule(torch.nn.Module):
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = kwargs.get("conv_dim", None)
|
||||
@@ -154,34 +156,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
"""
|
||||
block_dims = kwargs.get("block_dims")
|
||||
block_alphas = None
|
||||
# block dim/alpha/lr
|
||||
block_dims = kwargs.get("block_dims", None)
|
||||
down_lr_weight = kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = kwargs.get("up_lr_weight", None)
|
||||
|
||||
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
||||
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
||||
block_alphas = kwargs.get("block_alphas", None)
|
||||
conv_block_dims = kwargs.get("conv_block_dims", None)
|
||||
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
||||
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
)
|
||||
|
||||
# extract learning rate weight for each block
|
||||
if down_lr_weight is not None:
|
||||
# if some parameters are not set, use zero
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
# remove block dim/alpha without learning rate
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
)
|
||||
|
||||
if block_dims is not None:
|
||||
block_dims = [int(d) for d in block_dims.split(',')]
|
||||
assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||
block_alphas = kwargs.get("block_alphas")
|
||||
if block_alphas is None:
|
||||
block_alphas = [1] * len(block_dims)
|
||||
else:
|
||||
block_alphas = [int(a) for a in block_alphas(',')]
|
||||
assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||
|
||||
conv_block_dims = kwargs.get("conv_block_dims")
|
||||
conv_block_alphas = None
|
||||
|
||||
if conv_block_dims is not None:
|
||||
conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
|
||||
assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
|
||||
conv_block_alphas = kwargs.get("conv_block_alphas")
|
||||
if conv_block_alphas is None:
|
||||
conv_block_alphas = [1] * len(conv_block_dims)
|
||||
else:
|
||||
conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
|
||||
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
|
||||
"""
|
||||
block_alphas = None
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
@@ -190,28 +208,219 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
alpha=network_alpha,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
block_dims=block_dims,
|
||||
block_alphas=block_alphas,
|
||||
conv_block_dims=conv_block_dims,
|
||||
conv_block_alphas=conv_block_alphas,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
# if some parameters are not set, use zero
|
||||
up_lr_weight = kwargs.get("up_lr_weight", None)
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight = kwargs.get("down_lr_weight", None)
|
||||
if down_lr_weight is not None:
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
mid_lr_weight = kwargs.get("mid_lr_weight", None)
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
|
||||
return network
|
||||
|
||||
|
||||
# NOTE: may be called from outside this module.
# network_dim / network_alpha always arrive with their defaults already filled in.
# block_dims / block_alphas are either both None or both set.
# conv_dim / conv_alpha are either both None or both set.
def get_block_dims_and_alphas(
    block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
    """Resolve per-block dim/alpha settings from their comma-separated string forms.

    Returns ``(block_dims, block_alphas, conv_block_dims, conv_block_alphas)``.
    The first two always come back as lists of length ``NUM_OF_BLOCKS * 2 + 1``
    (down blocks + mid + up blocks); the conv pair stays ``None`` unless conv
    options were supplied, falling back to ``conv_dim``/``conv_alpha`` when only
    those are given.
    """
    # down blocks + mid block + up blocks
    num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1

    def parse_ints(s):
        return [int(i) for i in s.split(",")]

    def parse_floats(s):
        return [float(i) for i in s.split(",")]

    # parse block_dims and block_alphas; these always end up populated
    if block_dims is not None:
        block_dims = parse_ints(block_dims)
        assert (
            len(block_dims) == num_total_blocks
        ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
    else:
        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
        block_dims = [network_dim] * num_total_blocks

    if block_alphas is not None:
        block_alphas = parse_floats(block_alphas)
        assert (
            len(block_alphas) == num_total_blocks
        ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
    else:
        print(
            f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
        )
        block_alphas = [network_alpha] * num_total_blocks

    # parse conv_block_dims / conv_block_alphas only when given; otherwise fall
    # back to conv_dim / conv_alpha (or leave them None entirely)
    if conv_block_dims is not None:
        conv_block_dims = parse_ints(conv_block_dims)
        assert (
            len(conv_block_dims) == num_total_blocks
        ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"

        if conv_block_alphas is not None:
            conv_block_alphas = parse_floats(conv_block_alphas)
            assert (
                len(conv_block_alphas) == num_total_blocks
            ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
        else:
            if conv_alpha is None:
                conv_alpha = 1.0
            print(
                f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
            )
            conv_block_alphas = [conv_alpha] * num_total_blocks
    else:
        if conv_dim is not None:
            print(
                f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
            )
            conv_block_dims = [conv_dim] * num_total_blocks
            conv_block_alphas = [conv_alpha] * num_total_blocks
        else:
            conv_block_dims = None
            conv_block_alphas = None

    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# Defines the per-block multipliers applied to the learning rate.
# NOTE: may be called from outside this module.
def get_block_lr_weight(
    down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
) -> Tuple[List[float], List[float], List[float]]:
    """Normalize the per-block learning-rate weights.

    Each of ``down_lr_weight``/``up_lr_weight`` may be a preset-curve string
    (``cosine``/``sine``/``linear``/``reverse_linear``/``zeros``, optionally
    ``+<base>``) or a list of floats; ``mid_lr_weight`` is a single float.
    Lists are truncated/padded (with 1.0) to ``NUM_OF_BLOCKS`` entries, and
    weights at or below ``zero_threshold`` are zeroed so those blocks can be
    skipped. Returns ``(None, None, None)`` when nothing was specified.
    """
    # nothing specified -> keep the legacy behavior (no per-block LR)
    if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
        return None, None, None

    max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model

    def get_list(name_with_suffix) -> List[float]:
        import math

        tokens = name_with_suffix.split("+")
        name = tokens[0]
        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0

        if name == "cosine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
        elif name == "sine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
        elif name == "linear":
            return [i / (max_len - 1) + base_lr for i in range(max_len)]
        elif name == "reverse_linear":
            return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
        elif name == "zeros":
            return [0.0 + base_lr] * max_len
        else:
            print(
                "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
                % (name)
            )
            return None

    if type(down_lr_weight) == str:
        down_lr_weight = get_list(down_lr_weight)
    if type(up_lr_weight) == str:
        up_lr_weight = get_list(up_lr_weight)

    # FIX: only truncate the list that actually exists and is too long; the
    # previous code sliced both unconditionally, raising TypeError
    # (None[:max_len]) when only one of them was provided.
    if (up_lr_weight is not None and len(up_lr_weight) > max_len) or (
        down_lr_weight is not None and len(down_lr_weight) > max_len
    ):
        print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
        print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
        if up_lr_weight is not None:
            up_lr_weight = up_lr_weight[:max_len]
        if down_lr_weight is not None:
            down_lr_weight = down_lr_weight[:max_len]

    if (up_lr_weight is not None and len(up_lr_weight) < max_len) or (
        down_lr_weight is not None and len(down_lr_weight) < max_len
    ):
        print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
        print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)

        if down_lr_weight is not None and len(down_lr_weight) < max_len:
            down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
        if up_lr_weight is not None and len(up_lr_weight) < max_len:
            up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))

    if (up_lr_weight is not None) or (mid_lr_weight is not None) or (down_lr_weight is not None):
        print("apply block learning rate / 階層別学習率を適用します。")
        if down_lr_weight is not None:
            # weights at or below the threshold become 0 so the block is skipped
            down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
        else:
            print("down_lr_weight: all 1.0, すべて1.0")

        if mid_lr_weight is not None:
            mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
            print("mid_lr_weight:", mid_lr_weight)
        else:
            print("mid_lr_weight: 1.0")

        if up_lr_weight is not None:
            up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
        else:
            print("up_lr_weight: all 1.0, すべて1.0")

    return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
|
||||
# Drops blocks whose lr_weight is 0 from block_dims.
# NOTE: may be called from outside this module.
def remove_block_dims_and_alphas(
    block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
):
    """Zero the dim of every block whose learning-rate weight is 0.

    A dim of 0 means "create no LoRA module for this block" downstream.
    ``block_dims`` (and ``conv_block_dims`` when not None) are mutated in place
    and also returned; the alpha lists pass through untouched.
    Index layout: [0..NUM_OF_BLOCKS-1] down, [NUM_OF_BLOCKS] mid,
    [NUM_OF_BLOCKS+1..] up.
    """

    def _disable(idx):
        # setting dim to 0 removes the block
        block_dims[idx] = 0
        if conv_block_dims is not None:
            conv_block_dims[idx] = 0

    if down_lr_weight is not None:
        for i, lr in enumerate(down_lr_weight):
            if lr == 0:
                _disable(i)

    if mid_lr_weight is not None and mid_lr_weight == 0:
        _disable(LoRANetwork.NUM_OF_BLOCKS)

    if up_lr_weight is not None:
        for i, lr in enumerate(up_lr_weight):
            if lr == 0:
                _disable(LoRANetwork.NUM_OF_BLOCKS + 1 + i)

    return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
||||
|
||||
|
||||
# NOTE: may be called from outside this module.
def get_block_index(lora_name: str) -> int:
    """Map a U-Net LoRA module name to its block index.

    Layout: 1..NUM_OF_BLOCKS-1 are the down blocks (index 0 is never produced:
    no LoRA maps there), NUM_OF_BLOCKS is the mid block, and NUM_OF_BLOCKS+1
    onward are the up blocks. Returns -1 for names that belong to no block.
    """
    block_idx = -1  # invalid lora name

    m = RE_UPDOWN.search(lora_name)
    if m:
        g = m.groups()
        i = int(g[1])  # outer (up/down) block number
        j = int(g[3])  # sub-module number within the block
        # NOTE(review): assumes RE_UPDOWN only ever captures resnets /
        # attentions / upsamplers / downsamplers in g[2]; otherwise `idx`
        # below would be unbound — confirm against the pattern definition.
        if g[2] == "resnets":
            idx = 3 * i + j
        elif g[2] == "attentions":
            idx = 3 * i + j
        elif g[2] == "upsamplers" or g[2] == "downsamplers":
            idx = 3 * i + 2

        if g[0] == "down":
            block_idx = 1 + idx  # no LoRA corresponds to block index 0
        elif g[0] == "up":
            block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx

    elif "mid_block_" in lora_name:
        block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12

    return block_idx
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
@@ -242,8 +451,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
modules_alpha = modules_dim[key]
|
||||
|
||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
||||
network.weights_sd = weights_sd
|
||||
return network
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
@@ -265,9 +473,22 @@ class LoRANetwork(torch.nn.Module):
|
||||
alpha=1,
|
||||
conv_lora_dim=None,
|
||||
conv_alpha=None,
|
||||
block_dims=None,
|
||||
block_alphas=None,
|
||||
conv_block_dims=None,
|
||||
conv_block_alphas=None,
|
||||
modules_dim=None,
|
||||
modules_alpha=None,
|
||||
varbose=False,
|
||||
) -> None:
|
||||
"""
|
||||
LoRA network: すごく引数が多いが、パターンは以下の通り
|
||||
1. lora_dimとalphaを指定
|
||||
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
|
||||
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
|
||||
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
|
||||
5. modules_dimとmodules_alphaを指定 (推論用)
|
||||
"""
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
@@ -278,62 +499,83 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
print(f"conv_block_dims: {conv_block_dims}")
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
|
||||
self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
|
||||
if self.apply_to_conv2d_3x3:
|
||||
if self.conv_alpha is None:
|
||||
self.conv_alpha = self.alpha
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
if self.conv_lora_dim is not None:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
# create module instances
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
# TODO get block index here
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
if modules_dim is not None:
|
||||
if lora_name not in modules_dim:
|
||||
continue # no LoRA module in this weights file
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
elif is_unet and block_dims is not None:
|
||||
block_idx = get_block_index(lora_name)
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = block_dims[block_idx]
|
||||
alpha = block_alphas[block_idx]
|
||||
elif conv_block_dims is not None:
|
||||
dim = conv_block_dims[block_idx]
|
||||
alpha = conv_block_alphas[block_idx]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
elif self.apply_to_conv2d_3x3:
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
else:
|
||||
continue
|
||||
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
return loras, skipped
|
||||
|
||||
self.text_encoder_loras = create_modules(
|
||||
LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
)
|
||||
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.conv_lora_dim is not None:
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
|
||||
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
self.weights_sd = None
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
print(
|
||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
print(f"\t{name}")
|
||||
|
||||
self.up_lr_weight: List[float] = None
|
||||
self.down_lr_weight: List[float] = None
|
||||
@@ -351,39 +593,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
self.weights_sd = load_file(file)
|
||||
else:
|
||||
self.weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||
if self.weights_sd:
|
||||
weights_has_text_encoder = weights_has_unet = False
|
||||
for key in self.weights_sd.keys():
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
weights_has_text_encoder = True
|
||||
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||
weights_has_unet = True
|
||||
|
||||
if apply_text_encoder is None:
|
||||
apply_text_encoder = weights_has_text_encoder
|
||||
else:
|
||||
assert (
|
||||
apply_text_encoder == weights_has_text_encoder
|
||||
), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
||||
|
||||
if apply_unet is None:
|
||||
apply_unet = weights_has_unet
|
||||
else:
|
||||
assert (
|
||||
apply_unet == weights_has_unet
|
||||
), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
||||
else:
|
||||
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
@@ -394,32 +604,14 @@ class LoRANetwork(torch.nn.Module):
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
skipped = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight
|
||||
skipped.append(lora.lora_name)
|
||||
continue
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
if len(skipped) > 0:
|
||||
print(
|
||||
f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
print(f"\t{name}")
|
||||
|
||||
if self.weights_sd:
|
||||
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
||||
info = self.load_state_dict(self.weights_sd, False)
|
||||
print(f"weights are loaded: {info}")
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, dtype, device):
|
||||
assert self.weights_sd is not None, "weights are not loaded"
|
||||
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in self.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
||||
@@ -437,9 +629,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in self.weights_sd.keys():
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
@@ -447,113 +639,18 @@ class LoRANetwork(torch.nn.Module):
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
||||
def set_block_lr_weight(
|
||||
self,
|
||||
up_lr_weight: Union[List[float], str] = None,
|
||||
up_lr_weight: List[float] = None,
|
||||
mid_lr_weight: float = None,
|
||||
down_lr_weight: Union[List[float], str] = None,
|
||||
zero_threshold: float = 0.0,
|
||||
down_lr_weight: List[float] = None,
|
||||
):
|
||||
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
||||
return
|
||||
|
||||
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
||||
|
||||
def get_list(name_with_suffix) -> List[float]:
|
||||
import math
|
||||
|
||||
tokens = name_with_suffix.split("+")
|
||||
name = tokens[0]
|
||||
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
||||
|
||||
if name == "cosine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "sine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
||||
elif name == "linear":
|
||||
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
||||
elif name == "reverse_linear":
|
||||
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "zeros":
|
||||
return [0.0 + base_lr] * max_len
|
||||
else:
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
% (name)
|
||||
)
|
||||
return None
|
||||
|
||||
if type(down_lr_weight) == str:
|
||||
down_lr_weight = get_list(down_lr_weight)
|
||||
if type(up_lr_weight) == str:
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
print("apply block learning rate / 階層別学習率を適用します。")
|
||||
self.block_lr = True
|
||||
|
||||
if down_lr_weight != None:
|
||||
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
|
||||
else:
|
||||
print("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
print("mid_lr_weight:", self.mid_lr_weight)
|
||||
else:
|
||||
print("mid_lr_weight: 1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
|
||||
else:
|
||||
print("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return
|
||||
|
||||
def get_block_index(self, lora: LoRAModule) -> int:
|
||||
block_idx = -1 # invalid lora name
|
||||
|
||||
m = RE_UPDOWN.search(lora.lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
i = int(g[1])
|
||||
j = int(g[3])
|
||||
if g[2] == "resnets":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "attentions":
|
||||
idx = 3 * i + j
|
||||
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
||||
idx = 3 * i + 2
|
||||
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
||||
elif g[0] == "up":
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
||||
|
||||
elif "mid_block_" in lora.lora_name:
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
||||
|
||||
return block_idx
|
||||
self.block_lr = True
|
||||
self.down_lr_weight = down_lr_weight
|
||||
self.mid_lr_weight = mid_lr_weight
|
||||
self.up_lr_weight = up_lr_weight
|
||||
|
||||
def get_lr_weight(self, lora: LoRAModule) -> float:
|
||||
lr_weight = 1.0
|
||||
block_idx = self.get_block_index(lora)
|
||||
block_idx = get_block_index(lora.lora_name)
|
||||
if block_idx < 0:
|
||||
return lr_weight
|
||||
|
||||
@@ -590,7 +687,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
block_idx_to_lora = {}
|
||||
for lora in self.unet_loras:
|
||||
idx = self.get_block_index(lora)
|
||||
idx = get_block_index(lora.lora_name)
|
||||
if idx not in block_idx_to_lora:
|
||||
block_idx_to_lora[idx] = []
|
||||
block_idx_to_lora[idx].append(lora)
|
||||
|
||||
Reference in New Issue
Block a user