Change the number of layers (based on hako-mikan/sd-webui-lora-block-weight)
@@ -337,6 +337,7 @@ class LoRANetwork(torch.nn.Module):
         self.up_lr_weight:list[float] = None
         self.down_lr_weight:list[float] = None
         self.mid_lr_weight:float = None
+        self.stratified_lr = False
 
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
@@ -434,10 +435,7 @@ class LoRANetwork(torch.nn.Module):
 
     # define per-layer multipliers applied to the learning rate (for layer-wise learning rates)
     def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
-        max_len = 3 # defined for 3 modules: attentions -> attentions -> attentions
-        if self.apply_to_conv2d_3x3:
-            max_len = 10 # defined for 10 modules: (resnets -> {up,down}sampler -> attentions) x3 -> resnets
-
+        max_len=12 # number of up/down layers in the full model
         def get_list(name) -> list[float]:
             import math
             if name=="cosine":
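set_stratified_lr_weight also accepts a preset name such as "cosine" instead of an explicit weight list, and get_list expands it over the new max_len of 12. The body of get_list is cut off in this hunk, so the snippet below is only a rough sketch of what a cosine-shaped 12-entry multiplier list could look like; the helper name cosine_weights and the exact formula are assumptions, not the commit's code.

import math

def cosine_weights(max_len: int = 12) -> list[float]:
    # one plausible cosine-shaped curve over 12 layers; the real get_list may differ
    return [math.cos(math.pi * i / (max_len - 1)) / 2 + 0.5 for i in range(max_len)]

print(cosine_weights())  # 12 values descending from 1.0 to 0.0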
@@ -469,6 +467,7 @@ class LoRANetwork(torch.nn.Module):
             up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
         if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
             print("Applying layer-wise learning rates.")
+            self.stratified_lr = True
         if (down_lr_weight != None):
             self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
             print("down_lr_weight (shallow -> deep):",self.down_lr_weight)
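For reference, the padding and zero_threshold handling above works as follows; the weight values in this standalone snippet are made up for illustration, only the two list expressions are taken from the hunk.

max_len = 12
zero_threshold = 0.1

# a list shorter than max_len is right-padded with 1.0
up_lr_weight = [0.0, 0.5, 1.0]
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
print(up_lr_weight)    # [0.0, 0.5, 1.0, 1.0, ..., 1.0], now 12 entries long

# entries at or below zero_threshold become 0 (their param groups are later skipped)
down_lr_weight = [0.05, 0.2] + [1.0] * 10
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
print(down_lr_weight)  # [0, 0.2, 1.0, ..., 1.0]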
@@ -483,31 +482,22 @@ class LoRANetwork(torch.nn.Module):
     def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
-            idx = 0
             g = m.groups()
             i = int(g[1])
-            if self.apply_to_conv2d_3x3:
-                if g[2]=="resnets":
-                    idx=3*i
-                elif g[2]=="attentions":
-                    if g[0]=="down":
-                        idx=3*i + 2
-                    else:
-                        idx=3*i - 1
-                elif g[2]=="upsamplers" or g[2]=="downsamplers":
-                    idx=3*i + 1
-            else:
-                idx=i
-                if g[0]=="up":
-                    idx=i-1
+            j = int(g[3])
+            if g[2]=="resnets":
+                idx=3*i + j
+            elif g[2]=="attentions":
+                idx=3*i + j
+            elif g[2]=="upsamplers" or g[2]=="downsamplers":
+                idx=3*i + 2
 
-            if (g[0]=="up") and (self.up_lr_weight != None):
+            if (g[0]=="down") and (self.down_lr_weight != None):
+                return self.down_lr_weight[idx+1]
+            elif (g[0]=="up") and (self.up_lr_weight != None):
                 return self.up_lr_weight[idx]
-            elif (g[0]=="down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx]
-        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):
+        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
             return self.mid_lr_weight
-        # print({'params': lora.parameters(), 'lr':alpha*lr})
         return 1
 
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
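The new mapping computes a flat block index from the module name: 3*i + j for resnets and attentions and 3*i + 2 for the up/down samplers, where i is the block number and j the sub-module number; down_blocks then look up down_lr_weight[idx+1], and the added comment marks the mid block as idx=12. RE_UPDOWN is defined elsewhere in lora.py, so the regex, the block_index helper, and the module names below are assumptions used only to illustrate the arithmetic.

import re

# assumed pattern capturing (up|down), block index, module type, sub-module index
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

def block_index(lora_name: str):
    m = RE_UPDOWN.search(lora_name)
    if m is None:
        return None
    g = m.groups()
    i, j = int(g[1]), int(g[3])
    if g[2] == "resnets" or g[2] == "attentions":
        return 3 * i + j
    return 3 * i + 2  # upsamplers / downsamplers

# hypothetical module names, purely for illustration
print(block_index("lora_unet_down_blocks_1_attentions_0_attn1_to_q"))  # 3*1 + 0 = 3
print(block_index("lora_unet_up_blocks_2_resnets_1_conv1"))            # 3*2 + 1 = 7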
@@ -525,12 +515,14 @@ class LoRANetwork(torch.nn.Module):
 
         if self.unet_loras:
             for lora in self.unet_loras:
-                param_data={}
+                param_data = {'params': lora.parameters()}
                 if unet_lr is not None:
-                    param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr}
+                    param_data['lr'] = unet_lr
                 elif default_lr is not None:
-                    param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr}
-                if param_data["lr"]==0:
+                    param_data['lr'] = default_lr
+                if self.stratified_lr and ('lr' in param_data):
+                    param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
+                if (param_data['lr']==0):
                     continue
                 all_params.append(param_data)
         return all_params
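To see what the refactored loop produces, here is a toy, self-contained re-creation of the same param-group handling with stand-in modules and weights (none of the names or values below come from the commit): each group's learning rate is scaled by its per-layer weight, and groups whose rate ends up at 0 are dropped before the optimizer is built.

import torch

unet_lr = 1e-4
stratified_lr = True
loras = [  # stand-ins for LoRA modules and their get_stratified_lr_weight results
    {"name": "down_blocks_0", "weight": 0.0, "params": [torch.nn.Parameter(torch.zeros(4))]},
    {"name": "up_blocks_3",   "weight": 0.5, "params": [torch.nn.Parameter(torch.zeros(4))]},
]

all_params = []
for lora in loras:
    param_data = {"params": lora["params"]}
    if unet_lr is not None:
        param_data["lr"] = unet_lr
    if stratified_lr and ("lr" in param_data):
        param_data["lr"] = param_data["lr"] * lora["weight"]  # per-layer scaling
    if param_data["lr"] == 0:
        continue  # zero-weighted blocks are excluded from training
    all_params.append(param_data)

optimizer = torch.optim.AdamW(all_params)
print([g["lr"] for g in optimizer.param_groups])  # only the 0.5-weighted group remains -> [5e-05]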