mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
change 'stratify' to 'block', add en message
This commit is contained in:
121
networks/lora.py
121
networks/lora.py
@@ -12,7 +12,8 @@ import re
|
||||
|
||||
from library import train_util
|
||||
|
||||
RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
@@ -191,18 +192,22 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
conv_alpha=conv_alpha,
|
||||
)
|
||||
|
||||
up_lr_weight=None
|
||||
if 'up_lr_weight' in kwargs:
|
||||
up_lr_weight = kwargs.get('up_lr_weight',None)
|
||||
# if some parameters are not set, use zero
|
||||
up_lr_weight = kwargs.get("up_lr_weight", None)
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
|
||||
down_lr_weight=None
|
||||
if 'down_lr_weight' in kwargs:
|
||||
down_lr_weight = kwargs.get('down_lr_weight',None)
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight = kwargs.get("down_lr_weight", None)
|
||||
if down_lr_weight is not None:
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
|
||||
mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
|
||||
network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
mid_lr_weight = kwargs.get("mid_lr_weight", None)
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
|
||||
|
||||
return network
|
||||
|
||||
@@ -328,17 +333,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
self.weights_sd = None
|
||||
|
||||
self.up_lr_weight: list[float] = None
|
||||
self.down_lr_weight: list[float] = None
|
||||
self.mid_lr_weight: float = None
|
||||
self.block_lr = False
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
self.up_lr_weight:list[float] = None
|
||||
self.down_lr_weight:list[float] = None
|
||||
self.mid_lr_weight:float = None
|
||||
self.stratified_lr = False
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
@@ -389,13 +394,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
skipped = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if self.get_stratified_lr_weight(lora) == 0:
|
||||
if self.block_lr and self.get_block_lr_weight(lora) == 0:
|
||||
skipped.append(lora.lora_name)
|
||||
continue
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
if len(skipped) > 0:
|
||||
print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
|
||||
print(
|
||||
f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
print(f"\t{name}")
|
||||
|
||||
@@ -431,13 +439,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
||||
def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
|
||||
def set_block_lr_weight(
|
||||
self,
|
||||
up_lr_weight: list[float] | str = None,
|
||||
mid_lr_weight: float = None,
|
||||
down_lr_weight: list[float] | str = None,
|
||||
zero_threshold: float = 0.0,
|
||||
):
|
||||
# バラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
||||
return
|
||||
|
||||
max_len = 12 # フルモデル相当でのup,downの層の数
|
||||
|
||||
def get_list(name) -> list[float]:
|
||||
import math
|
||||
|
||||
if name == "cosine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
|
||||
elif name == "sine":
|
||||
@@ -449,7 +470,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
elif name == "zeros":
|
||||
return [0.0] * max_len
|
||||
else:
|
||||
print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
% (name)
|
||||
)
|
||||
return None
|
||||
|
||||
if type(down_lr_weight) == str:
|
||||
@@ -458,28 +482,45 @@ class LoRANetwork(torch.nn.Module):
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
print("層別学習率を適用します。")
|
||||
self.stratified_lr = True
|
||||
if (down_lr_weight != None):
|
||||
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
|
||||
print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
|
||||
if (mid_lr_weight != None):
|
||||
print("apply block learning rate / 階層別学習率を適用します。")
|
||||
self.block_lr = True
|
||||
|
||||
if down_lr_weight != None:
|
||||
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
|
||||
else:
|
||||
print("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
print("mid_lr_weight:", self.mid_lr_weight)
|
||||
if (up_lr_weight != None):
|
||||
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
|
||||
print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
|
||||
else:
|
||||
print("mid_lr_weight: 1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
|
||||
else:
|
||||
print("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return
|
||||
|
||||
def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
|
||||
def get_block_lr_weight(self, lora: LoRAModule) -> float:
|
||||
m = RE_UPDOWN.search(lora.lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
@@ -494,7 +535,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if (g[0] == "down") and (self.down_lr_weight != None):
|
||||
return self.down_lr_weight[idx + 1]
|
||||
elif (g[0]=="up") and (self.up_lr_weight != None):
|
||||
if (g[0] == "up") and (self.up_lr_weight != None):
|
||||
return self.up_lr_weight[idx]
|
||||
elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
|
||||
return self.mid_lr_weight
|
||||
@@ -508,29 +549,29 @@ class LoRANetwork(torch.nn.Module):
|
||||
params = []
|
||||
for lora in self.text_encoder_loras:
|
||||
params.extend(lora.parameters())
|
||||
param_data = {'params': params}
|
||||
param_data = {"params": params}
|
||||
if text_encoder_lr is not None:
|
||||
param_data['lr'] = text_encoder_lr
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
if self.stratified_lr:
|
||||
if self.block_lr:
|
||||
for lora in self.unet_loras:
|
||||
param_data = {'params': lora.parameters()}
|
||||
param_data = {"params": lora.parameters()}
|
||||
if unet_lr is not None:
|
||||
param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
|
||||
param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
|
||||
elif default_lr is not None:
|
||||
param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
|
||||
if ('lr' in param_data) and (param_data['lr']==0):
|
||||
param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
else:
|
||||
params = []
|
||||
for lora in self.unet_loras:
|
||||
params.extend(lora.parameters())
|
||||
param_data = {'params': params}
|
||||
param_data = {"params": params}
|
||||
if unet_lr is not None:
|
||||
param_data['lr'] = unet_lr
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
Reference in New Issue
Block a user