From 058e442072582f16a591bf5fb5f395f953767501 Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Sun, 2 Apr 2023 04:02:34 +0900
Subject: [PATCH] Change the number of layers (referencing
 hako-mikan/sd-webui-lora-block-weight)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 networks/lora.py | 52 ++++++++++++++++++++----------------------------
 1 file changed, 22 insertions(+), 30 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index bb8f356e..cfc517ce 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -337,6 +337,7 @@ class LoRANetwork(torch.nn.Module):
         self.up_lr_weight:list[float] = None
         self.down_lr_weight:list[float] = None
         self.mid_lr_weight:float = None
+        self.stratified_lr = False
 
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
@@ -434,10 +435,7 @@ class LoRANetwork(torch.nn.Module):
 
     # Define per-layer multipliers on the learning rate, for stratified (layer-wise) learning rates
    def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
-        max_len = 3  # defined over 3 modules: attentions -> attentions -> attentions
-        if self.apply_to_conv2d_3x3:
-            max_len = 10  # defined over 10 modules: (resnets -> {up,down}sampler -> attentions) x3 -> resnets
-
+        max_len=12 # number of up/down layers in the full-size model
         def get_list(name) -> list[float]:
             import math
             if name=="cosine":
@@ -469,6 +467,7 @@ class LoRANetwork(torch.nn.Module):
             up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
         if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
             print("Applying stratified learning rates.")
+            self.stratified_lr = True
             if (down_lr_weight != None):
                 self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
                 print("down_lr_weight (shallow -> deep layers):",self.down_lr_weight)
@@ -483,31 +482,22 @@ class LoRANetwork(torch.nn.Module):
     def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
-            idx = 0
             g = m.groups()
             i = int(g[1])
-            if self.apply_to_conv2d_3x3:
-                if g[2]=="resnets":
-                    idx=3*i
-                elif g[2]=="attentions":
-                    if g[0]=="down":
-                        idx=3*i + 2
-                    else:
-                        idx=3*i - 1
-                elif g[2]=="upsamplers" or g[2]=="downsamplers":
-                    idx=3*i + 1
-            else:
-                idx=i
-                if g[0]=="up":
-                    idx=i-1
+            j = int(g[3])
+            if g[2]=="resnets":
+                idx=3*i + j
+            elif g[2]=="attentions":
+                idx=3*i + j
+            elif g[2]=="upsamplers" or g[2]=="downsamplers":
+                idx=3*i + 2
 
-            if (g[0]=="up") and (self.up_lr_weight != None):
+            if (g[0]=="down") and (self.down_lr_weight != None):
+                return self.down_lr_weight[idx+1]
+            elif (g[0]=="up") and (self.up_lr_weight != None):
                 return self.up_lr_weight[idx]
-            elif (g[0]=="down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx]
-        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):
-            return self.mid_lr_weight
-        # print({'params': lora.parameters(), 'lr':alpha*lr})
+        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
+            return self.mid_lr_weight
         return 1
 
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
@@ -525,13 +515,15 @@ class LoRANetwork(torch.nn.Module):
 
         if self.unet_loras:
             for lora in self.unet_loras:
-                param_data={}
+                param_data = {'params': lora.parameters()}
                 if unet_lr is not None:
-                    param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr}
+                    param_data['lr'] = unet_lr
                 elif default_lr is not None:
-                    param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr}
-                if param_data["lr"]==0:
-                    continue
+                    param_data['lr'] = default_lr
+                if self.stratified_lr and ('lr' in param_data):
+                    param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
+                if (param_data['lr']==0):
+                    continue
                 all_params.append(param_data)
 
         return all_params
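For orientation, the following standalone sketch (not part of the patch) illustrates the 12-slot block indexing that the reworked get_stratified_lr_weight implements. RE_UPDOWN is defined elsewhere in networks/lora.py and is not visible in these hunks, so the pattern below is an assumption inferred from how groups g[0]..g[3] are used; stratified_index is a hypothetical helper for illustration only.

# Standalone sketch (assumed context, not part of the patch) of the new
# 12-slot block indexing used by get_stratified_lr_weight.
import re

# Assumed pattern: g[0] = "up"/"down", g[1] = block index i,
# g[2] = module kind, g[3] = sub-module index j within the block.
RE_UPDOWN = re.compile(
    r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"
)

def stratified_index(lora_name):
    # Map a LoRA module name to its slot in down_lr_weight / up_lr_weight.
    m = RE_UPDOWN.search(lora_name)
    if m is None:
        return None
    g = m.groups()
    i, j = int(g[1]), int(g[3])
    if g[2] == "resnets" or g[2] == "attentions":
        idx = 3 * i + j    # three slots per block: two numbered sub-modules...
    else:
        idx = 3 * i + 2    # ...plus the up/down sampler in the last slot
    # The patch then reads down_lr_weight[idx+1] for down blocks and
    # up_lr_weight[idx] for up blocks.
    return g[0], idx

print(stratified_index("lora_unet_down_blocks_1_attentions_0_proj_in"))  # ('down', 3)

Under the same assumptions, a training script would call set_stratified_lr_weight with up-to-12-element weight lists (or a preset name such as "cosine", handled by get_list) before prepare_optimizer_params; once self.stratified_lr is set, each UNet LoRA module's learning rate is scaled by its block weight, and modules whose scaled rate is 0 are skipped entirely.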