change 'stratify' to 'block', add en message

Kohya S
2023-04-02 16:10:09 +09:00
parent 36c8a4aee7
commit 97e65bf93f


@@ -12,7 +12,8 @@ import re
 from library import train_util
-RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
 class LoRAModule(torch.nn.Module):
     """
@@ -191,18 +192,22 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_alpha=conv_alpha,
     )
-    up_lr_weight=None
-    if 'up_lr_weight' in kwargs:
-        up_lr_weight = kwargs.get('up_lr_weight',None)
+    # if some parameters are not set, use zero
+    up_lr_weight = kwargs.get("up_lr_weight", None)
+    if up_lr_weight is not None:
         if "," in up_lr_weight:
-            up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
-    down_lr_weight=None
-    if 'down_lr_weight' in kwargs:
-        down_lr_weight = kwargs.get('down_lr_weight',None)
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+    down_lr_weight = kwargs.get("down_lr_weight", None)
+    if down_lr_weight is not None:
         if "," in down_lr_weight:
-            down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
-    mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
-    network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+    mid_lr_weight = kwargs.get("mid_lr_weight", None)
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+    network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
     return network
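As a side note, the parsing above is easiest to follow in isolation. The sketch below mirrors the new create_network() logic with illustrative values; in sd-scripts these key=value pairs typically arrive via --network_args, but the exact command line is not part of this diff and the sample weights are hypothetical.

def parse_block_lr_kwargs(**kwargs):
    # mirrors create_network(): comma-separated strings become float lists,
    # empty entries become 0.0, preset names stay as plain strings
    up_lr_weight = kwargs.get("up_lr_weight", None)
    if up_lr_weight is not None and "," in up_lr_weight:
        up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]

    down_lr_weight = kwargs.get("down_lr_weight", None)
    if down_lr_weight is not None and "," in down_lr_weight:
        down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]

    mid_lr_weight = kwargs.get("mid_lr_weight", None)
    if mid_lr_weight is not None:
        mid_lr_weight = float(mid_lr_weight)

    zero_threshold = float(kwargs.get("block_lr_zero_threshold", 0.0))
    return up_lr_weight, mid_lr_weight, down_lr_weight, zero_threshold

# example: full rate on the first half of the down blocks, empty entries -> 0.0
print(parse_block_lr_kwargs(
    down_lr_weight="1,1,1,1,1,1,,,,,,",
    mid_lr_weight="0.5",
    up_lr_weight="cosine",   # preset name, resolved later by set_block_lr_weight()
))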
@@ -328,17 +333,17 @@ class LoRANetwork(torch.nn.Module):
         self.weights_sd = None
+        self.up_lr_weight: list[float] = None
+        self.down_lr_weight: list[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
         # assertion
         names = set()
         for lora in self.text_encoder_loras + self.unet_loras:
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
-        self.up_lr_weight:list[float] = None
-        self.down_lr_weight:list[float] = None
-        self.mid_lr_weight:float = None
-        self.stratified_lr = False
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -389,13 +394,16 @@ class LoRANetwork(torch.nn.Module):
         skipped = []
         for lora in self.text_encoder_loras + self.unet_loras:
-            if self.get_stratified_lr_weight(lora) == 0:
+            if self.block_lr and self.get_block_lr_weight(lora) == 0:
                 skipped.append(lora.lora_name)
                 continue
             lora.apply_to()
             self.add_module(lora.lora_name, lora)
-        if len(skipped)>0:
-            print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
+        if len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
             for name in skipped:
                 print(f"\t{name}")
@@ -431,76 +439,109 @@ class LoRANetwork(torch.nn.Module):
                 if key.startswith(lora.lora_name):
                     sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
             lora.merge_to(sd_for_lora, dtype, device)
         print(f"weights are merged")
     # define the per-block multipliers applied to the learning rate (block-wise learning rate)
-    def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
-        max_len=12 # フルモデル相当でのup,downの層の数
-        def get_list(name) -> list[float]:
-            import math
-            if name=="cosine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))]
-            elif name=="sine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
-            elif name=="linear":
-                return [i/(max_len-1) for i in range(max_len)]
-            elif name=="reverse_linear":
-                return [i/(max_len-1) for i in reversed(range(max_len))]
-            elif name=="zeros":
-                return [0.0] * max_len
-            else:
-                print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
-                return None
-        if type(down_lr_weight)==str:
-            down_lr_weight=get_list(down_lr_weight)
-        if type(up_lr_weight)==str:
-            up_lr_weight=get_list(up_lr_weight)
-        if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
-            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len)
-        if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
-            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"%max_len)
-            if down_lr_weight != None and len(down_lr_weight)<max_len:
-                down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
-            if up_lr_weight != None and len(up_lr_weight)<max_len:
-                up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
-        if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
-            print("層別学習率を適用します。")
-            self.stratified_lr = True
-        if (down_lr_weight != None):
-            self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
-            print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
-        if (mid_lr_weight != None):
-            self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-            print("mid_lr_weight:",self.mid_lr_weight)
-        if (up_lr_weight != None):
-            self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
-            print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: list[float] | str = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: list[float] | str = None,
+        zero_threshold: float = 0.0,
+    ):
+        # if no parameter is specified, do nothing and keep the previous behavior
+        if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
+            return
-    def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
+        max_len = 12  # number of up/down block entries for the full model
+        def get_list(name) -> list[float]:
+            import math
+            if name == "cosine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+            elif name == "sine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+            elif name == "linear":
+                return [i / (max_len - 1) for i in range(max_len)]
+            elif name == "reverse_linear":
+                return [i / (max_len - 1) for i in reversed(range(max_len))]
+            elif name == "zeros":
+                return [0.0] * max_len
+            else:
+                print(
+                    "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
+                    % (name, name)
+                )
+                return None
+        if type(down_lr_weight) == str:
+            down_lr_weight = get_list(down_lr_weight)
+        if type(up_lr_weight) == str:
+            up_lr_weight = get_list(up_lr_weight)
+        if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+            print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+            up_lr_weight = up_lr_weight[:max_len] if up_lr_weight is not None else None
+            down_lr_weight = down_lr_weight[:max_len] if down_lr_weight is not None else None
+        if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+            print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+            if down_lr_weight != None and len(down_lr_weight) < max_len:
+                down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+            if up_lr_weight != None and len(up_lr_weight) < max_len:
+                up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+        if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
+            print("apply block learning rate / 階層別学習率を適用します。")
+            self.block_lr = True
+            if down_lr_weight != None:
+                self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+                print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
+            else:
+                print("down_lr_weight: all 1.0, すべて1.0")
+            if mid_lr_weight != None:
+                self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+                print("mid_lr_weight:", self.mid_lr_weight)
+            else:
+                print("mid_lr_weight: 1.0")
+            if up_lr_weight != None:
+                self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+                print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
+            else:
+                print("up_lr_weight: all 1.0, すべて1.0")
+        return
+    def get_block_lr_weight(self, lora: LoRAModule) -> float:
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
             g = m.groups()
             i = int(g[1])
             j = int(g[3])
-            if g[2]=="resnets":
-                idx=3*i + j
-            elif g[2]=="attentions":
-                idx=3*i + j
-            elif g[2]=="upsamplers" or g[2]=="downsamplers":
-                idx=3*i + 2
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
-            if (g[0]=="down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx+1]
-            elif (g[0]=="up") and (self.up_lr_weight != None):
+            if (g[0] == "down") and (self.down_lr_weight != None):
+                return self.down_lr_weight[idx + 1]
+            if (g[0] == "up") and (self.up_lr_weight != None):
                 return self.up_lr_weight[idx]
         elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
             return self.mid_lr_weight
         return 1
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
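Two details of the new code are easy to misread, so here is a small sketch (not from the repo) that reproduces the preset curves of get_list() and the block-index arithmetic of get_block_lr_weight(); the module name used at the end is hypothetical but follows this file's "lora_unet_" + path-with-underscores naming.

import math
import re

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
MAX_LEN = 12  # entries per up/down weight list, as in set_block_lr_weight()

def preset(name: str) -> list[float]:
    # same curves as get_list() above
    if name == "cosine":
        return [math.sin(math.pi * (i / (MAX_LEN - 1)) / 2) for i in reversed(range(MAX_LEN))]
    if name == "sine":
        return [math.sin(math.pi * (i / (MAX_LEN - 1)) / 2) for i in range(MAX_LEN)]
    if name == "linear":
        return [i / (MAX_LEN - 1) for i in range(MAX_LEN)]
    if name == "reverse_linear":
        return [i / (MAX_LEN - 1) for i in reversed(range(MAX_LEN))]
    if name == "zeros":
        return [0.0] * MAX_LEN
    raise ValueError(f"unknown preset: {name}")

print([round(w, 2) for w in preset("sine")])
# [0.0, 0.14, 0.28, 0.42, 0.54, 0.65, 0.76, 0.84, 0.91, 0.96, 0.99, 1.0]

# index arithmetic from get_block_lr_weight(); the name below is hypothetical
name = "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q"
g = RE_UPDOWN.search(name).groups()
print(g)  # ('down', '1', 'attentions', '0')
idx = 3 * int(g[1]) + int(g[3])  # 3*1 + 0 = 3
# a "down" module reads down_lr_weight[idx + 1], an "up" module reads up_lr_weight[idx],
# mid_block modules use the scalar mid_lr_weight, and anything unconfigured falls back to 1

Note that down weights are indexed at idx + 1, so as written the first entry of down_lr_weight is never read by a module matched by RE_UPDOWN.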
@@ -508,29 +549,29 @@ class LoRANetwork(torch.nn.Module):
             params = []
             for lora in self.text_encoder_loras:
                 params.extend(lora.parameters())
-            param_data = {'params': params}
+            param_data = {"params": params}
             if text_encoder_lr is not None:
-                param_data['lr'] = text_encoder_lr
+                param_data["lr"] = text_encoder_lr
             all_params.append(param_data)
         if self.unet_loras:
-            if self.stratified_lr:
+            if self.block_lr:
                 for lora in self.unet_loras:
-                    param_data = {'params': lora.parameters()}
+                    param_data = {"params": lora.parameters()}
                     if unet_lr is not None:
-                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                        param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
                     elif default_lr is not None:
-                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
-                    if ('lr' in param_data) and (param_data['lr']==0):
+                        param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
                         continue
                     all_params.append(param_data)
             else:
                 params = []
                 for lora in self.unet_loras:
                     params.extend(lora.parameters())
-                param_data = {'params': params}
+                param_data = {"params": params}
                 if unet_lr is not None:
-                    param_data['lr'] = unet_lr
+                    param_data["lr"] = unet_lr
                 all_params.append(param_data)
         return all_params
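For completeness, a minimal usage sketch (not part of this commit) of how the returned param groups would typically be handed to an optimizer; the network instance, the learning rates, and the optimizer choice are all assumptions.

import torch

# `network` is assumed to be a LoRANetwork built by create_network(...) above
text_encoder_lr, unet_lr, default_lr = 5e-5, 1e-4, 1e-4
param_groups = network.prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)

# each group carries its own "lr": U-Net LoRA modules get unet_lr scaled by their
# block weight, and groups whose lr would be 0 were skipped above
optimizer = torch.optim.AdamW(param_groups, lr=default_lr)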