diff --git a/networks/lora.py b/networks/lora.py index 2bf78511..4dbf79f9 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -8,9 +8,11 @@ import os from typing import List import numpy as np import torch +import re from library import train_util +RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_') class LoRAModule(torch.nn.Module): """ @@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un else: conv_block_alphas = [int(a) for a in conv_block_alphas(',')] assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - """ + """ network = LoRANetwork( text_encoder, @@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_lora_dim=conv_dim, conv_alpha=conv_alpha, ) + + up_weight=None + if 'up_weight' in kwargs: + up_weight = kwargs.get('up_weight',None) + if "," in up_weight: + up_weight = [float(s) for s in up_weight.split(",") if s] + down_weight=None + if 'down_weight' in kwargs: + down_weight = kwargs.get('down_weight',None) + if "," in down_weight: + down_weight = [float(s) for s in down_weight.split(",") if s] + + network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('lr_weight_threshold', 0.0))) + return network @@ -318,6 +334,10 @@ class LoRANetwork(torch.nn.Module): assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) + self.up_weight:list[float] = None + self.down_weight:list[float] = None + self.mid_weight:float = None + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -366,9 +386,17 @@ class LoRANetwork(torch.nn.Module): else: self.unet_loras = [] + skipped = [] for lora in self.text_encoder_loras + self.unet_loras: + if self.get_stratified_lr_weight(lora) == 0: + skipped.append(lora.lora_name) + continue lora.apply_to() self.add_module(lora.lora_name, lora) + if len(skipped)>0: + print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:") + for name in skipped: + print(f"\t{name}") if self.weights_sd: # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) @@ -404,34 +432,113 @@ class LoRANetwork(torch.nn.Module): lora.merge_to(sd_for_lora, dtype, device) print(f"weights are merged") - def enable_gradient_checkpointing(self): - # not supported - pass + # 層別学習率用に層ごとの学習率に対する倍率を定義する + def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0): + max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義 + if self.apply_to_conv2d_3x3: + max_len = 10 # (resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義 - def prepare_optimizer_params(self, text_encoder_lr, unet_lr): - def enumerate_params(loras): - params = [] - for lora in loras: - params.extend(lora.parameters()) - return params + def get_list(name) -> list[float]: + import math + if name=="cosine": + return [math.cos(math.pi*(i/(max_len-1))/2) for i in range(max_len)] + elif name=="sine": + return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)] + elif name=="linear": + return [i/(max_len-1) for i in range(max_len)] + elif name=="reverse_linear": + return [i/(max_len-1) for i in reversed(range(max_len))] + elif name=="zeros": + return [0.0] * max_len + else: + print("不明なweightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name)) + return None + if type(down_weight)==str: + down_weight=get_list(down_weight) + if type(up_weight)==str: + up_weight=get_list(up_weight) + + if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len): + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len) + if (up_weight != None and len(up_weight) zero_threshold else 0 for w in down_weight[:max_len]] + print("down_weight(浅い層->深い層):",self.down_weight) + if (mid_weight != None): + self.mid_weight = mid_weight if mid_weight > zero_threshold else 0 + print("mid_weight:",self.mid_weight) + if (up_weight != None): + self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]] + print("up_weight(深い層->浅い層):",self.up_weight) + return + + def get_stratified_lr_weight(self, lora:LoRAModule) -> float: + m = RE_UPDOWN.search(lora.lora_name) + if m: + idx = 0 + g = m.groups() + i = int(g[1]) + if self.apply_to_conv2d_3x3: + if g[2]=="resnets": + idx=3*i + elif g[2]=="attentions": + if g[0]=="down": + idx=3*i + 2 + else: + idx=3*i - 1 + elif g[2]=="upsamplers" or g[2]=="downsamplers": + idx=3*i + 1 + else: + idx=i + if g[0]=="up": + idx=i-1 + + if (g[0]=="up") and (self.up_weight != None): + return self.up_weight[idx] + elif (g[0]=="down") and (self.down_weight != None): + return self.down_weight[idx] + elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None): + return self.mid_weight + # print({'params': lora.parameters(), 'lr':alpha*lr}) + return 1 + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr): self.requires_grad_(True) all_params = [] if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} + params = [] + for lora in self.text_encoder_loras: + params.extend(lora.parameters()) + param_data = {'params': params} if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr + param_data['lr'] = text_encoder_lr all_params.append(param_data) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - + for lora in self.unet_loras: + param_data={} + if unet_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr} + elif default_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr} + if param_data["lr"]==0: + continue + all_params.append(param_data) return all_params + def enable_gradient_checkpointing(self): + # not supported + pass + def prepare_grad_etc(self, text_encoder, unet): self.requires_grad_(True) diff --git a/train_network.py b/train_network.py index 200d8d84..eb5301e2 100644 --- a/train_network.py +++ b/train_network.py @@ -191,7 +191,7 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する