引数名に_lrを追加

This commit is contained in:
u-haru
2023-03-31 01:40:29 +09:00
parent dade23a414
commit 1b75dbd4f2

View File

@@ -191,18 +191,18 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_alpha=conv_alpha,
)
up_weight=None
if 'up_weight' in kwargs:
up_weight = kwargs.get('up_weight',None)
if "," in up_weight:
up_weight = [float(s) for s in up_weight.split(",") if s]
down_weight=None
if 'down_weight' in kwargs:
down_weight = kwargs.get('down_weight',None)
if "," in down_weight:
down_weight = [float(s) for s in down_weight.split(",") if s]
network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
up_lr_weight=None
if 'up_lr_weight' in kwargs:
up_lr_weight = kwargs.get('up_lr_weight',None)
if "," in up_lr_weight:
up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
down_lr_weight=None
if 'down_lr_weight' in kwargs:
down_lr_weight = kwargs.get('down_lr_weight',None)
if "," in down_lr_weight:
down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
return network
@@ -334,9 +334,9 @@ class LoRANetwork(torch.nn.Module):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
self.up_weight:list[float] = None
self.down_weight:list[float] = None
self.mid_weight:float = None
self.up_lr_weight:list[float] = None
self.down_lr_weight:list[float] = None
self.mid_lr_weight:float = None
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -433,7 +433,7 @@ class LoRANetwork(torch.nn.Module):
print(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する
def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0):
def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義
if self.apply_to_conv2d_3x3:
max_len = 10 # (resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義
@@ -451,33 +451,33 @@ class LoRANetwork(torch.nn.Module):
elif name=="zeros":
return [0.0] * max_len
else:
print("不明なweightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
return None
if type(down_weight)==str:
down_weight=get_list(down_weight)
if type(up_weight)==str:
up_weight=get_list(up_weight)
if type(down_lr_weight)==str:
down_lr_weight=get_list(down_lr_weight)
if type(up_lr_weight)==str:
up_lr_weight=get_list(up_lr_weight)
if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len):
if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len)
if (up_weight != None and len(up_weight)<max_len) or (down_weight != None and len(down_weight)<max_len):
if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"%max_len)
if down_weight != None and len(down_weight)<max_len:
down_weight = down_weight + [1.0] * (max_len - len(down_weight))
if up_weight != None and len(up_weight)<max_len:
up_weight = up_weight + [1.0] * (max_len - len(up_weight))
if (up_weight != None) or (mid_weight != None) or (down_weight != None):
if down_lr_weight != None and len(down_lr_weight)<max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight)<max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("層別学習率を適用します。")
if (down_weight != None):
self.down_weight = [w if w > zero_threshold else 0 for w in down_weight[:max_len]]
print("down_weight(浅い層->深い層):",self.down_weight)
if (mid_weight != None):
self.mid_weight = mid_weight if mid_weight > zero_threshold else 0
print("mid_weight:",self.mid_weight)
if (up_weight != None):
self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]]
print("up_weight(深い層->浅い層):",self.up_weight)
if (down_lr_weight != None):
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
if (mid_lr_weight != None):
self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:",self.mid_lr_weight)
if (up_lr_weight != None):
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
return
def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
@@ -501,12 +501,12 @@ class LoRANetwork(torch.nn.Module):
if g[0]=="up":
idx=i-1
if (g[0]=="up") and (self.up_weight != None):
return self.up_weight[idx]
elif (g[0]=="down") and (self.down_weight != None):
return self.down_weight[idx]
elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None):
return self.mid_weight
if (g[0]=="up") and (self.up_lr_weight != None):
return self.up_lr_weight[idx]
elif (g[0]=="down") and (self.down_lr_weight != None):
return self.down_lr_weight[idx]
elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):
return self.mid_lr_weight
# print({'params': lora.parameters(), 'lr':alpha*lr})
return 1