From 97e65bf93fb609da4df280e83839087f4743b744 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 2 Apr 2023 16:10:09 +0900
Subject: [PATCH] change 'stratify' to 'block', add en message

---
 networks/lora.py | 189 ++++++++++++++++++++++++++++-------------------
 1 file changed, 115 insertions(+), 74 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index 6e860a03..f1a65074 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -12,7 +12,8 @@ import re
 
 from library import train_util
 
-RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
 
 class LoRAModule(torch.nn.Module):
     """
@@ -191,18 +192,22 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_alpha=conv_alpha,
     )
 
-    up_lr_weight=None
-    if 'up_lr_weight' in kwargs:
-        up_lr_weight = kwargs.get('up_lr_weight',None)
+    # empty entries in a comma-separated list are treated as zero
+    up_lr_weight = kwargs.get("up_lr_weight", None)
+    if up_lr_weight is not None:
         if "," in up_lr_weight:
-            up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
-    down_lr_weight=None
-    if 'down_lr_weight' in kwargs:
-        down_lr_weight = kwargs.get('down_lr_weight',None)
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+    down_lr_weight = kwargs.get("down_lr_weight", None)
+    if down_lr_weight is not None:
         if "," in down_lr_weight:
-            down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
-    mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
-    network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+    mid_lr_weight = kwargs.get("mid_lr_weight", None)
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+
+    network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)))
 
     return network
@@ -328,17 +333,17 @@ class LoRANetwork(torch.nn.Module):
 
         self.weights_sd = None
 
+        self.up_lr_weight: list[float] = None
+        self.down_lr_weight: list[float] = None
+        self.mid_lr_weight: float = None
+        self.block_lr = False
+
         # assertion
         names = set()
         for lora in self.text_encoder_loras + self.unet_loras:
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
-        self.up_lr_weight:list[float] = None
-        self.down_lr_weight:list[float] = None
-        self.mid_lr_weight:float = None
-        self.stratified_lr = False
-
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -389,13 +394,16 @@ class LoRANetwork(torch.nn.Module):
 
         skipped = []
         for lora in self.text_encoder_loras + self.unet_loras:
-            if self.get_stratified_lr_weight(lora) == 0:
+            if self.block_lr and self.get_block_lr_weight(lora) == 0:
                 skipped.append(lora.lora_name)
                 continue
             lora.apply_to()
             self.add_module(lora.lora_name, lora)
-        if len(skipped)>0:
-            print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
+
+        if len(skipped) > 0:
+            print(
+                f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
+            )
             for name in skipped:
                 print(f"\t{name}")
 
@@ -431,76 +439,109 @@ class LoRANetwork(torch.nn.Module):
                 if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key]
             lora.merge_to(sd_for_lora, dtype, device)
+        print("weights are merged")
 
     # define the multiplier for the learning rate of each block (block-wise learning rate)
-    def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
-        max_len=12 # フルモデル相当でのup,downの層の数
+    def set_block_lr_weight(
+        self,
+        up_lr_weight: list[float] | str = None,
+        mid_lr_weight: float = None,
+        down_lr_weight: list[float] | str = None,
+        zero_threshold: float = 0.0,
+    ):
+        # if no parameter is specified, do nothing and keep the previous behavior
+        if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
+            return
+
+        max_len = 12  # number of up/down blocks in the full model
+
         def get_list(name) -> list[float]:
-            import math
-            if name=="cosine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))]
-            elif name=="sine":
-                return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
-            elif name=="linear":
-                return [i/(max_len-1) for i in range(max_len)]
-            elif name=="reverse_linear":
-                return [i/(max_len-1) for i in reversed(range(max_len))]
-            elif name=="zeros":
+            import math
+
+            if name == "cosine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+            elif name == "sine":
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+            elif name == "linear":
+                return [i / (max_len - 1) for i in range(max_len)]
+            elif name == "reverse_linear":
+                return [i / (max_len - 1) for i in reversed(range(max_len))]
+            elif name == "zeros":
                 return [0.0] * max_len
             else:
-                print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
+                print(
+                    "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
+                    % (name, name)
+                )
                 return None
 
-        if type(down_lr_weight)==str:
-            down_lr_weight=get_list(down_lr_weight)
-        if type(up_lr_weight)==str:
-            up_lr_weight=get_list(up_lr_weight)
+        if type(down_lr_weight) == str:
+            down_lr_weight = get_list(down_lr_weight)
+        if type(up_lr_weight) == str:
+            up_lr_weight = get_list(up_lr_weight)
 
-        if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
-            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len)
-        if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
-            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"%max_len)
-            if down_lr_weight != None and len(down_lr_weight)<max_len:
+        if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+            print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
+            print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
+            up_lr_weight = up_lr_weight[:max_len]
+            down_lr_weight = down_lr_weight[:max_len]
+
+        if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+            print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
+            print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
+
+            if down_lr_weight != None and len(down_lr_weight) < max_len:
                 down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
-            if up_lr_weight != None and len(up_lr_weight)<max_len:
+            if up_lr_weight != None and len(up_lr_weight) < max_len:
                 up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
 
-        if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
-            print("層別学習率を適用します。")
-            self.stratified_lr = True
-        if (down_lr_weight != None):
-            self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
-            print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
-        if (mid_lr_weight != None):
+        print("apply block learning rate / 階層別学習率を適用します。")
+        self.block_lr = True
+
+        if down_lr_weight != None:
+            self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
+            print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight)
+        else:
+            print("down_lr_weight: all 1.0, すべて1.0")
+
+        if mid_lr_weight != None:
             self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
-            print("mid_lr_weight:",self.mid_lr_weight)
-        if (up_lr_weight != None):
-            self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
-            print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
+            print("mid_lr_weight:", self.mid_lr_weight)
+        else:
+            print("mid_lr_weight: 1.0")
+
+        if up_lr_weight != None:
+            self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
+            print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight)
+        else:
+            print("up_lr_weight: all 1.0, すべて1.0")
+
         return
 
-    def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
+    def get_block_lr_weight(self, lora: LoRAModule) -> float:
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
             g = m.groups()
             i = int(g[1])
             j = int(g[3])
-            if g[2]=="resnets":
-                idx=3*i + j
-            elif g[2]=="attentions":
-                idx=3*i + j
-            elif g[2]=="upsamplers" or g[2]=="downsamplers":
-                idx=3*i + 2
+            if g[2] == "resnets":
+                idx = 3 * i + j
+            elif g[2] == "attentions":
+                idx = 3 * i + j
+            elif g[2] == "upsamplers" or g[2] == "downsamplers":
+                idx = 3 * i + 2
 
-            if (g[0]=="down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx+1]
-            elif (g[0]=="up") and (self.up_lr_weight != None):
+            if (g[0] == "down") and (self.down_lr_weight != None):
+                return self.down_lr_weight[idx + 1]
+            if (g[0] == "up") and (self.up_lr_weight != None):
                 return self.up_lr_weight[idx]
-        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
-            return self.mid_lr_weight
+        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
+            return self.mid_lr_weight
         return 1
 
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []
 
@@ -508,29 +549,29 @@ class LoRANetwork(torch.nn.Module):
             params = []
             for lora in self.text_encoder_loras:
                 params.extend(lora.parameters())
-            param_data = {'params': params}
+            param_data = {"params": params}
             if text_encoder_lr is not None:
-                param_data['lr'] = text_encoder_lr
+                param_data["lr"] = text_encoder_lr
             all_params.append(param_data)
 
         if self.unet_loras:
-            if self.stratified_lr:
+            if self.block_lr:
                 for lora in self.unet_loras:
-                    param_data = {'params': lora.parameters()}
+                    param_data = {"params": lora.parameters()}
                     if unet_lr is not None:
-                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                        param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
                     elif default_lr is not None:
-                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
-                    if ('lr' in param_data) and (param_data['lr']==0):
+                        param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
+                    if ("lr" in param_data) and (param_data["lr"] == 0):
                        continue
                     all_params.append(param_data)
             else:
                 params = []
                 for lora in self.unet_loras:
                     params.extend(lora.parameters())
-                param_data = {'params': params}
+                param_data = {"params": params}
                 if unet_lr is not None:
-                    param_data['lr'] = unet_lr
+                    param_data["lr"] = unet_lr
                 all_params.append(param_data)
 
         return all_params
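
For reference, here is a minimal standalone sketch of the block-weight logic above, separate from the patch itself. It mirrors the preset weight lists built by get_list() and the block-index computation in get_block_lr_weight(); RE_UPDOWN is copied from the patch, while preset() and block_index() are hypothetical helper names used only for illustration:

    import math
    import re

    MAX_LEN = 12  # matches max_len in set_block_lr_weight: up/down entries in the full model

    RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")

    def preset(name: str) -> list[float]:
        # mirrors get_list(): twelve weights, one per up/down entry
        if name == "cosine":
            return [math.sin(math.pi * (i / (MAX_LEN - 1)) / 2) for i in reversed(range(MAX_LEN))]
        if name == "sine":
            return [math.sin(math.pi * (i / (MAX_LEN - 1)) / 2) for i in range(MAX_LEN)]
        if name == "linear":
            return [i / (MAX_LEN - 1) for i in range(MAX_LEN)]
        if name == "reverse_linear":
            return [i / (MAX_LEN - 1) for i in reversed(range(MAX_LEN))]
        if name == "zeros":
            return [0.0] * MAX_LEN
        raise ValueError(f"unknown preset: {name}")

    def block_index(lora_name: str) -> int | None:
        # mirrors the idx computation in get_block_lr_weight(): each up/down block
        # contributes three slots, and up/down-samplers take the block's last slot
        m = RE_UPDOWN.search(lora_name)
        if m is None:
            return None  # mid block or text encoder module, handled separately in the patch
        i, kind, j = int(m.group(2)), m.group(3), int(m.group(4))
        return 3 * i + (2 if kind in ("upsamplers", "downsamplers") else j)

    print([round(w, 3) for w in preset("sine")])                        # shallow -> deep ramp
    print(block_index("lora_unet_down_blocks_1_attentions_0_proj_in"))  # -> 3

With the patched create_network, these weights are supplied as network kwargs, e.g. down_lr_weight=sine, up_lr_weight=cosine, mid_lr_weight=0.5, block_lr_zero_threshold=0.1 (in sd-scripts these normally arrive via --network_args). Two behaviors follow from the patch itself: weights at or below block_lr_zero_threshold are clamped to 0 and the corresponding modules are skipped, and the down path reads down_lr_weight[idx + 1], so the down list is offset by one slot relative to the up list.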