mirror of https://github.com/kohya-ss/sd-scripts.git
change 'stratify' to 'block', add en message
networks/lora.py (205 changed lines)
@@ -12,7 +12,8 @@ import re
 
 from library import train_util
 
-RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
 
 class LoRAModule(torch.nn.Module):
     """
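For reference, only the quoting style of RE_UPDOWN changes here; the capture groups are untouched. A minimal standalone check of what the pattern extracts (the module name below is illustrative of the lora_unet_* naming convention; mid_block names deliberately do not match and are handled separately in get_block_lr_weight):

import re

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

# groups: direction, block index, sub-module type, sub-module index
m = RE_UPDOWN.search("lora_unet_down_blocks_1_attentions_0_proj_in")
print(m.groups())  # ('down', '1', 'attentions', '0')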
@@ -191,18 +192,22 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_alpha=conv_alpha,
     )
 
-    up_lr_weight=None
-    if 'up_lr_weight' in kwargs:
-        up_lr_weight = kwargs.get('up_lr_weight',None)
+    # if some parameters are not set, use zero
+    up_lr_weight = kwargs.get("up_lr_weight", None)
+    if up_lr_weight is not None:
         if "," in up_lr_weight:
-            up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
-    down_lr_weight=None
-    if 'down_lr_weight' in kwargs:
-        down_lr_weight = kwargs.get('down_lr_weight',None)
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+    down_lr_weight = kwargs.get("down_lr_weight", None)
+    if down_lr_weight is not None:
         if "," in down_lr_weight:
-            down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
-    mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
-    network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+    mid_lr_weight = kwargs.get("mid_lr_weight", None)
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+
+    network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
 
     return network
 
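The rewritten kwargs handling accepts either a preset name (resolved later by get_list) or a comma-separated list, and now maps empty items to 0.0 instead of dropping them, so later positions keep their alignment. A minimal sketch of that parsing, with an illustrative kwargs dict:

kwargs = {"up_lr_weight": "1,0.5,,0"}

up_lr_weight = kwargs.get("up_lr_weight", None)
if up_lr_weight is not None and "," in up_lr_weight:
    # empty items become 0.0; the old `if s` filter silently dropped them,
    # which shifted every later block weight one position forward
    up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]

print(up_lr_weight)  # [1.0, 0.5, 0.0, 0.0]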
@@ -328,17 +333,17 @@ class LoRANetwork(torch.nn.Module):
 
         self.weights_sd = None
 
+        self.up_lr_weight: list[float] = None
+        self.down_lr_weight: list[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
+
         # assertion
         names = set()
         for lora in self.text_encoder_loras + self.unet_loras:
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
-        self.up_lr_weight:list[float] = None
-        self.down_lr_weight:list[float] = None
-        self.mid_lr_weight:float = None
-        self.stratified_lr = False
-
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -389,13 +394,16 @@ class LoRANetwork(torch.nn.Module):
 
         skipped = []
         for lora in self.text_encoder_loras + self.unet_loras:
-            if self.get_stratified_lr_weight(lora) == 0:
+            if self.block_lr and self.get_block_lr_weight(lora) == 0:
                 skipped.append(lora.lora_name)
                 continue
             lora.apply_to()
             self.add_module(lora.lora_name, lora)
-        if len(skipped)>0:
-            print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
+
+        if len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
             for name in skipped:
                 print(f"\t{name}")
 
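Note the added self.block_lr guard: get_block_lr_weight is consulted only when block-wise learning rates were actually configured, and modules whose weight resolves to 0 are never registered at all, so they appear in neither the optimizer nor the saved network. A standalone mimic of the gate (names and weights are illustrative):

block_lr = True
resolved = {
    "lora_unet_down_blocks_0_attentions_0": 0.0,
    "lora_unet_mid_block_attentions_0": 1.0,
}

skipped = []
for name, w in resolved.items():
    if block_lr and w == 0:
        skipped.append(name)
        continue
    # lora.apply_to() and self.add_module(...) would run here

if len(skipped) > 0:
    print(f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped")
    for name in skipped:
        print(f"\t{name}")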
@@ -431,76 +439,109 @@ class LoRANetwork(torch.nn.Module):
                 if key.startswith(lora.lora_name):
                     sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
             lora.merge_to(sd_for_lora, dtype, device)
 
         print(f"weights are merged")
 
     # 層別学習率用に層ごとの学習率に対する倍率を定義する
-    def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
-        max_len=12 # フルモデル相当でのup,downの層の数
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: list[float] | str = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: list[float] | str = None,
+        zero_threshold: float = 0.0,
+    ):
+        # パラメータ未指定時は何もせず、今までと同じ動作とする
+        if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
+            return
+
+        max_len = 12  # フルモデル相当でのup,downの層の数
+
         def get_list(name) -> list[float]:
             import math
-            if name=="cosine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))]
-            elif name=="sine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
-            elif name=="linear":
-                return [i/(max_len-1) for i in range(max_len)]
-            elif name=="reverse_linear":
-                return [i/(max_len-1) for i in reversed(range(max_len))]
-            elif name=="zeros":
+
+            if name == "cosine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+            elif name == "sine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+            elif name == "linear":
+                return [i / (max_len - 1) for i in range(max_len)]
+            elif name == "reverse_linear":
+                return [i / (max_len - 1) for i in reversed(range(max_len))]
+            elif name == "zeros":
                 return [0.0] * max_len
             else:
-                print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
+                print(
+                    "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
+                    % (name, name)
+                )
                 return None
 
-        if type(down_lr_weight)==str:
-            down_lr_weight=get_list(down_lr_weight)
-        if type(up_lr_weight)==str:
-            up_lr_weight=get_list(up_lr_weight)
+        if type(down_lr_weight) == str:
+            down_lr_weight = get_list(down_lr_weight)
+        if type(up_lr_weight) == str:
+            up_lr_weight = get_list(up_lr_weight)
 
-        if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
-            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len)
-        if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
-            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"%max_len)
-        if down_lr_weight != None and len(down_lr_weight)<max_len:
-            down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
-        if up_lr_weight != None and len(up_lr_weight)<max_len:
-            up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
-        if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
-            print("層別学習率を適用します。")
-            self.stratified_lr = True
-            if (down_lr_weight != None):
-                self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
-                print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
-            if (mid_lr_weight != None):
-                self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-                print("mid_lr_weight:",self.mid_lr_weight)
-            if (up_lr_weight != None):
-                self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
-                print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
+        if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+            print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+            up_lr_weight = up_lr_weight[:max_len]
+            down_lr_weight = down_lr_weight[:max_len]
+
+        if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+            print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+
+            if down_lr_weight != None and len(down_lr_weight) < max_len:
+                down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+            if up_lr_weight != None and len(up_lr_weight) < max_len:
+                up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+
+        if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
+            print("apply block learning rate / 階層別学習率を適用します。")
+            self.block_lr = True
+
+            if down_lr_weight != None:
+                self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+                print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
+            else:
+                print("down_lr_weight: all 1.0, すべて1.0")
+
+            if mid_lr_weight != None:
+                self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+                print("mid_lr_weight:", self.mid_lr_weight)
+            else:
+                print("mid_lr_weight: 1.0")
+
+            if up_lr_weight != None:
+                self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+                print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
+            else:
+                print("up_lr_weight: all 1.0, すべて1.0")
 
         return
 
-    def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
+    def get_block_lr_weight(self, lora: LoRAModule) -> float:
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
             g = m.groups()
             i = int(g[1])
             j = int(g[3])
-            if g[2]=="resnets":
-                idx=3*i + j
-            elif g[2]=="attentions":
-                idx=3*i + j
-            elif g[2]=="upsamplers" or g[2]=="downsamplers":
-                idx=3*i + 2
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
 
-            if (g[0]=="down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx+1]
-            elif (g[0]=="up") and (self.up_lr_weight != None):
+            if (g[0] == "down") and (self.down_lr_weight != None):
+                return self.down_lr_weight[idx + 1]
+            if (g[0] == "up") and (self.up_lr_weight != None):
                 return self.up_lr_weight[idx]
         elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
             return self.mid_lr_weight
         return 1
 
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
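The presets and the index mapping can be sanity-checked standalone. get_list produces max_len = 12 weights per side, and get_block_lr_weight folds (block i, sub-module j) into idx = 3*i + j (or 3*i + 2 for up/downsamplers), with the down side read at idx + 1. A minimal sketch of the "sine" preset (values rounded for display):

import math

max_len = 12  # number of up/down entries for the full model

# "sine" preset: weights rise from 0.0 at index 0 to 1.0 at index 11;
# "cosine" is the same list reversed, "linear" is simply i / (max_len - 1)
sine = [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
print([round(w, 2) for w in sine])
# [0.0, 0.14, 0.28, 0.42, 0.54, 0.65, 0.76, 0.84, 0.91, 0.96, 0.99, 1.0]

# Index mapping example from get_block_lr_weight:
# "down_blocks_1_attentions_0" -> i=1, j=0 -> idx = 3*1 + 0 = 3,
# and the down side then reads down_lr_weight[idx + 1] == down_lr_weight[4].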
@@ -508,29 +549,29 @@ class LoRANetwork(torch.nn.Module):
             params = []
             for lora in self.text_encoder_loras:
                 params.extend(lora.parameters())
-            param_data = {'params': params}
+            param_data = {"params": params}
             if text_encoder_lr is not None:
-                param_data['lr'] = text_encoder_lr
+                param_data["lr"] = text_encoder_lr
             all_params.append(param_data)
 
         if self.unet_loras:
-            if self.stratified_lr:
+            if self.block_lr:
                 for lora in self.unet_loras:
-                    param_data = {'params': lora.parameters()}
+                    param_data = {"params": lora.parameters()}
                     if unet_lr is not None:
-                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                        param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
                     elif default_lr is not None:
-                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
-                    if ('lr' in param_data) and (param_data['lr']==0):
+                        param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
                         continue
                     all_params.append(param_data)
             else:
                 params = []
                 for lora in self.unet_loras:
                     params.extend(lora.parameters())
-                param_data = {'params': params}
+                param_data = {"params": params}
                 if unet_lr is not None:
-                    param_data['lr'] = unet_lr
+                    param_data["lr"] = unet_lr
                 all_params.append(param_data)
 
         return all_params
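With block_lr set, prepare_optimizer_params emits one optimizer parameter group per UNet LoRA module, scaling unet_lr (or default_lr) by the block weight and dropping any group whose lr resolves to 0. A minimal sketch with torch; the module names and weights are hypothetical stand-ins for get_block_lr_weight results:

import torch

unet_lr = 1e-4
loras = {
    "down_blocks_0": (torch.nn.Linear(4, 4), 0.0),  # weight 0 -> group dropped
    "mid_block": (torch.nn.Linear(4, 4), 0.5),
    "up_blocks_11": (torch.nn.Linear(4, 4), 1.0),
}

all_params = []
for name, (module, weight) in loras.items():
    param_data = {"params": module.parameters(), "lr": unet_lr * weight}
    if param_data["lr"] == 0:
        continue  # mirrors the zero-lr skip in prepare_optimizer_params
    all_params.append(param_data)

optimizer = torch.optim.AdamW(all_params)
print([group["lr"] for group in optimizer.param_groups])  # [5e-05, 0.0001]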