From b1dffe8d9ae1c02a06e8871a844c42d6729623ce Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 00:11:11 +0900 Subject: [PATCH 01/24] =?UTF-8?q?=E3=83=95=E3=82=A1=E3=82=A4=E3=83=AB?= =?UTF-8?q?=E3=83=AD=E3=83=BC=E3=83=89=E3=81=8C=E3=81=A7=E3=81=8D=E3=81=AA?= =?UTF-8?q?=E3=81=84=E3=83=90=E3=82=B0=E4=BF=AE=E6=AD=A3(Exception:=20devi?= =?UTF-8?q?ce=20cuda=20is=20invalid)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/model_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index e227ced8..f3f236af 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -831,7 +831,7 @@ def is_safetensors(path): return os.path.splitext(path)[1].lower() == '.safetensors' -def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # text encoderの格納形式が違うモデルに対応する ('text_model'がない) TEXT_ENCODER_KEY_REPLACEMENTS = [ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), @@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) # no need to specify device in loading state_dict # Convert the UNet2DConditionModel model. 
unet_config = create_unet_diffusers_config(v2) From 4dacc52bde623e7a562d054585250ba8a7737c0a Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 00:39:35 +0900 Subject: [PATCH 02/24] implement stratified_lr --- networks/lora.py | 141 +++++++++++++++++++++++++++++++++++++++++------ train_network.py | 2 +- 2 files changed, 125 insertions(+), 18 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 2bf78511..4dbf79f9 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -8,9 +8,11 @@ import os from typing import List import numpy as np import torch +import re from library import train_util +RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_') class LoRAModule(torch.nn.Module): """ @@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un else: conv_block_alphas = [int(a) for a in conv_block_alphas(',')] assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - """ + """ network = LoRANetwork( text_encoder, @@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_lora_dim=conv_dim, conv_alpha=conv_alpha, ) + + up_weight=None + if 'up_weight' in kwargs: + up_weight = kwargs.get('up_weight',None) + if "," in up_weight: + up_weight = [float(s) for s in up_weight.split(",") if s] + down_weight=None + if 'down_weight' in kwargs: + down_weight = kwargs.get('down_weight',None) + if "," in down_weight: + down_weight = [float(s) for s in down_weight.split(",") if s] + + network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('lr_weight_threshold', 0.0))) + return network @@ -318,6 +334,10 @@ class LoRANetwork(torch.nn.Module): assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) + self.up_weight:list[float] = None + 
self.down_weight:list[float] = None + self.mid_weight:float = None + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -366,9 +386,17 @@ class LoRANetwork(torch.nn.Module): else: self.unet_loras = [] + skipped = [] for lora in self.text_encoder_loras + self.unet_loras: + if self.get_stratified_lr_weight(lora) == 0: + skipped.append(lora.lora_name) + continue lora.apply_to() self.add_module(lora.lora_name, lora) + if len(skipped)>0: + print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:") + for name in skipped: + print(f"\t{name}") if self.weights_sd: # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) @@ -404,34 +432,113 @@ class LoRANetwork(torch.nn.Module): lora.merge_to(sd_for_lora, dtype, device) print(f"weights are merged") - def enable_gradient_checkpointing(self): - # not supported - pass + # 層別学習率用に層ごとの学習率に対する倍率を定義する + def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0): + max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義 + if self.apply_to_conv2d_3x3: + max_len = 10 # (resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義 - def prepare_optimizer_params(self, text_encoder_lr, unet_lr): - def enumerate_params(loras): - params = [] - for lora in loras: - params.extend(lora.parameters()) - return params + def get_list(name) -> list[float]: + import math + if name=="cosine": + return [math.cos(math.pi*(i/(max_len-1))/2) for i in range(max_len)] + elif name=="sine": + return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)] + elif name=="linear": + return [i/(max_len-1) for i in range(max_len)] + elif name=="reverse_linear": + return [i/(max_len-1) for i in reversed(range(max_len))] + elif name=="zeros": + return [0.0] * max_len + else: + print("不明なweightの引数 
%s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name)) + return None + if type(down_weight)==str: + down_weight=get_list(down_weight) + if type(up_weight)==str: + up_weight=get_list(up_weight) + + if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len): + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len) + if (up_weight != None and len(up_weight) zero_threshold else 0 for w in down_weight[:max_len]] + print("down_weight(浅い層->深い層):",self.down_weight) + if (mid_weight != None): + self.mid_weight = mid_weight if mid_weight > zero_threshold else 0 + print("mid_weight:",self.mid_weight) + if (up_weight != None): + self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]] + print("up_weight(深い層->浅い層):",self.up_weight) + return + + def get_stratified_lr_weight(self, lora:LoRAModule) -> float: + m = RE_UPDOWN.search(lora.lora_name) + if m: + idx = 0 + g = m.groups() + i = int(g[1]) + if self.apply_to_conv2d_3x3: + if g[2]=="resnets": + idx=3*i + elif g[2]=="attentions": + if g[0]=="down": + idx=3*i + 2 + else: + idx=3*i - 1 + elif g[2]=="upsamplers" or g[2]=="downsamplers": + idx=3*i + 1 + else: + idx=i + if g[0]=="up": + idx=i-1 + + if (g[0]=="up") and (self.up_weight != None): + return self.up_weight[idx] + elif (g[0]=="down") and (self.down_weight != None): + return self.down_weight[idx] + elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None): + return self.mid_weight + # print({'params': lora.parameters(), 'lr':alpha*lr}) + return 1 + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr): self.requires_grad_(True) all_params = [] if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} + params = [] + for lora in self.text_encoder_loras: + params.extend(lora.parameters()) + param_data = {'params': params} if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr + 
param_data['lr'] = text_encoder_lr all_params.append(param_data) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - + for lora in self.unet_loras: + param_data={} + if unet_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr} + elif default_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr} + if param_data["lr"]==0: + continue + all_params.append(param_data) return all_params + def enable_gradient_checkpointing(self): + # not supported + pass + def prepare_grad_etc(self, text_encoder, unet): self.requires_grad_(True) diff --git a/train_network.py b/train_network.py index 200d8d84..eb5301e2 100644 --- a/train_network.py +++ b/train_network.py @@ -191,7 +191,7 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する From 313f3e82862078319a400bb163e5980171d16b12 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 30 Mar 2023 12:08:04 -0400 Subject: [PATCH 03/24] Open dataset_config json file before load --- library/config_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/config_util.py b/library/config_util.py index 97bbb4a8..3064117f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -486,7 +486,8 @@ def load_user_config(file: str) -> dict: if file.name.lower().endswith('.json'): try: - config = json.load(file) + with open(file, 'r') as f: + config = json.load(f) except Exception: print(f"Error on parsing JSON config file. Please check the format. 
/ JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise From dade23a4149494e8eb9342463aba13a6f6e04b98 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 01:14:03 +0900 Subject: [PATCH 04/24] =?UTF-8?q?stratified=5Fzero=5Fthreshold=E3=81=AB?= =?UTF-8?q?=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 4dbf79f9..ad8331c8 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -202,7 +202,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un if "," in down_weight: down_weight = [float(s) for s in down_weight.split(",") if s] - network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('lr_weight_threshold', 0.0))) + network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('stratified_zero_threshold', 0.0))) return network From 1b75dbd4f2553bdc09fdfc1d10fa007926a907b5 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 01:40:29 +0900 Subject: [PATCH 05/24] =?UTF-8?q?=E5=BC=95=E6=95=B0=E5=90=8D=E3=81=AB=5Flr?= =?UTF-8?q?=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 86 ++++++++++++++++++++++++------------------------ 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index ad8331c8..f60789f8 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -191,18 +191,18 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha=conv_alpha, ) - up_weight=None - if 'up_weight' in kwargs: - up_weight = kwargs.get('up_weight',None) - if "," in up_weight: - up_weight = [float(s) 
for s in up_weight.split(",") if s] - down_weight=None - if 'down_weight' in kwargs: - down_weight = kwargs.get('down_weight',None) - if "," in down_weight: - down_weight = [float(s) for s in down_weight.split(",") if s] - - network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('stratified_zero_threshold', 0.0))) + up_lr_weight=None + if 'up_lr_weight' in kwargs: + up_lr_weight = kwargs.get('up_lr_weight',None) + if "," in up_lr_weight: + up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s] + down_lr_weight=None + if 'down_lr_weight' in kwargs: + down_lr_weight = kwargs.get('down_lr_weight',None) + if "," in down_lr_weight: + down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s] + mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None + network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0))) return network @@ -334,9 +334,9 @@ class LoRANetwork(torch.nn.Module): assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) - self.up_weight:list[float] = None - self.down_weight:list[float] = None - self.mid_weight:float = None + self.up_lr_weight:list[float] = None + self.down_lr_weight:list[float] = None + self.mid_lr_weight:float = None def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -433,7 +433,7 @@ class LoRANetwork(torch.nn.Module): print(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する - def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0): + def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0): max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義 if self.apply_to_conv2d_3x3: max_len = 10 # 
(resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義 @@ -451,33 +451,33 @@ class LoRANetwork(torch.nn.Module): elif name=="zeros": return [0.0] * max_len else: - print("不明なweightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name)) + print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name)) return None - if type(down_weight)==str: - down_weight=get_list(down_weight) - if type(up_weight)==str: - up_weight=get_list(up_weight) + if type(down_lr_weight)==str: + down_lr_weight=get_list(down_lr_weight) + if type(up_lr_weight)==str: + up_lr_weight=get_list(up_lr_weight) - if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len): + if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len): print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len) - if (up_weight != None and len(up_weight) zero_threshold else 0 for w in down_weight[:max_len]] - print("down_weight(浅い層->深い層):",self.down_weight) - if (mid_weight != None): - self.mid_weight = mid_weight if mid_weight > zero_threshold else 0 - print("mid_weight:",self.mid_weight) - if (up_weight != None): - self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]] - print("up_weight(深い層->浅い層):",self.up_weight) + if (down_lr_weight != None): + self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]] + print("down_lr_weight(浅い層->深い層):",self.down_lr_weight) + if (mid_lr_weight != None): + self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + print("mid_lr_weight:",self.mid_lr_weight) + if (up_lr_weight != None): + self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]] + print("up_lr_weight(深い層->浅い層):",self.up_lr_weight) return def get_stratified_lr_weight(self, lora:LoRAModule) -> float: @@ -501,12 +501,12 @@ class 
LoRANetwork(torch.nn.Module): if g[0]=="up": idx=i-1 - if (g[0]=="up") and (self.up_weight != None): - return self.up_weight[idx] - elif (g[0]=="down") and (self.down_weight != None): - return self.down_weight[idx] - elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None): - return self.mid_weight + if (g[0]=="up") and (self.up_lr_weight != None): + return self.up_lr_weight[idx] + elif (g[0]=="down") and (self.down_lr_weight != None): + return self.down_lr_weight[idx] + elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): + return self.mid_lr_weight # print({'params': lora.parameters(), 'lr':alpha*lr}) return 1 From 3032a47af4ce9d08628561f1b759975951d3ddc3 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 01:42:57 +0900 Subject: [PATCH 06/24] =?UTF-8?q?cosine=E3=82=92sine=E3=81=AEreversed?= =?UTF-8?q?=E3=81=AB=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index f60789f8..bb8f356e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -441,7 +441,7 @@ class LoRANetwork(torch.nn.Module): def get_list(name) -> list[float]: import math if name=="cosine": - return [math.cos(math.pi*(i/(max_len-1))/2) for i in range(max_len)] + return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))] elif name=="sine": return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)] elif name=="linear": From ccb0ef518a912620812f69d5c125565f160965c3 Mon Sep 17 00:00:00 2001 From: Atsumu Ono Date: Fri, 31 Mar 2023 01:45:49 +0900 Subject: [PATCH 07/24] fix typo --- train_README-ja.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_README-ja.md b/train_README-ja.md index d5f1b5fc..032e006b 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -801,7 +801,7 @@ 
model_dirオプションでモデルの保存先フォルダを指定できま キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。`--full_path` オプションを指定してメタデータに画像ファイルの場所をフルパスで格納します。このオプションを省略すると相対パスで記録されますが、フォルダ指定が `.toml` ファイル内で別途必要になります。 ``` -python merge_captions_to_metadata.py --full_apth <教師データフォルダ> +python merge_captions_to_metadata.py --full_path <教師データフォルダ>   --in_json <読み込むメタデータファイル名> <メタデータファイル名> ``` From 94441fa7468b90571e3c4107758639f3e441ee13 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 02:26:54 +0900 Subject: [PATCH 08/24] =?UTF-8?q?=E7=B9=B0=E3=82=8A=E8=BF=94=E3=81=97?= =?UTF-8?q?=E5=9B=9E=E6=95=B0=E3=81=AE=E3=81=AA=E3=81=84=E3=83=87=E3=82=A3?= =?UTF-8?q?=E3=83=AC=E3=82=AF=E3=83=88=E3=83=AA=E3=81=AE=E5=90=8D=E5=89=8D?= =?UTF-8?q?=E8=A1=A8=E7=A4=BA=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/config_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/config_util.py b/library/config_util.py index 97bbb4a8..21764675 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -445,7 +445,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] try: n_repeats = int(tokens[0]) except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") return 0, "" caption_by_folder = '_'.join(tokens[1:]) return n_repeats, caption_by_folder From 1e164b6ec37eff1034c213628dfc75105922b233 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 12:52:39 +0900 Subject: [PATCH 09/24] specify device when loading state_dict --- library/model_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/model_util.py b/library/model_util.py index 9b4405eb..32a9c87a 100644 --- a/library/model_util.py +++ 
b/library/model_util.py @@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) # no need to specify device in loading state_dict + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(v2) From 9577a9f38d74a96f3d50a4caec5a3bc578e10169 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sat, 1 Apr 2023 20:33:20 +0900 Subject: [PATCH 10/24] Check needless num_warmup_steps --- library/train_util.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 59dbc44c..a195faac 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2460,7 +2460,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): Unified API to get any scheduler from its name. """ name = args.lr_scheduler - num_warmup_steps = args.lr_warmup_steps + num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -2484,6 +2484,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): lr_scheduler_kwargs[key] = value + def wrap_check_needless_num_warmup_steps(return_vals): + if num_warmup_steps is not None and num_warmup_steps != 0: + raise ValueError(f"{name} does not require `num_warmup_steps`. 
Set None or 0.") + return return_vals + # using any lr_scheduler from other library if args.lr_scheduler_type: lr_scheduler_type = args.lr_scheduler_type @@ -2496,7 +2501,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): lr_scheduler_type = values[-1] lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) - return lr_scheduler + return wrap_check_needless_num_warmup_steps(lr_scheduler) if name.startswith("adafactor"): assert ( @@ -2504,12 +2509,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" initial_lr = float(name.split(":")[1]) # print("adafactor scheduler init lr", initial_lr) - return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) + return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer)) # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: From 058e442072582f16a591bf5fb5f395f953767501 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 2 Apr 2023 04:02:34 +0900 Subject: [PATCH 11/24] =?UTF-8?q?=E3=83=AC=E3=82=A4=E3=83=A4=E3=83=BC?= =?UTF-8?q?=E6=95=B0=E5=A4=89=E6=9B=B4(hako-mikan/sd-webui-lora-block-weig?= =?UTF-8?q?ht=E5=8F=82=E8=80=83)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 52 ++++++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index bb8f356e..cfc517ce 100644 --- a/networks/lora.py +++ 
b/networks/lora.py @@ -337,6 +337,7 @@ class LoRANetwork(torch.nn.Module): self.up_lr_weight:list[float] = None self.down_lr_weight:list[float] = None self.mid_lr_weight:float = None + self.stratified_lr = False def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -434,10 +435,7 @@ class LoRANetwork(torch.nn.Module): # 層別学習率用に層ごとの学習率に対する倍率を定義する def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0): - max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義 - if self.apply_to_conv2d_3x3: - max_len = 10 # (resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義 - + max_len=12 # フルモデル相当でのup,downの層の数 def get_list(name) -> list[float]: import math if name=="cosine": @@ -469,6 +467,7 @@ class LoRANetwork(torch.nn.Module): up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): print("層別学習率を適用します。") + self.stratified_lr = True if (down_lr_weight != None): self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]] print("down_lr_weight(浅い層->深い層):",self.down_lr_weight) @@ -483,31 +482,22 @@ class LoRANetwork(torch.nn.Module): def get_stratified_lr_weight(self, lora:LoRAModule) -> float: m = RE_UPDOWN.search(lora.lora_name) if m: - idx = 0 g = m.groups() i = int(g[1]) - if self.apply_to_conv2d_3x3: - if g[2]=="resnets": - idx=3*i - elif g[2]=="attentions": - if g[0]=="down": - idx=3*i + 2 - else: - idx=3*i - 1 - elif g[2]=="upsamplers" or g[2]=="downsamplers": - idx=3*i + 1 - else: - idx=i - if g[0]=="up": - idx=i-1 + j = int(g[3]) + if g[2]=="resnets": + idx=3*i + j + elif g[2]=="attentions": + idx=3*i + j + elif g[2]=="upsamplers" or g[2]=="downsamplers": + idx=3*i + 2 - if (g[0]=="up") and (self.up_lr_weight != None): + if (g[0]=="down") and (self.down_lr_weight != None): + return 
self.down_lr_weight[idx+1] + elif (g[0]=="up") and (self.up_lr_weight != None): return self.up_lr_weight[idx] - elif (g[0]=="down") and (self.down_lr_weight != None): - return self.down_lr_weight[idx] - elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): - return self.mid_lr_weight - # print({'params': lora.parameters(), 'lr':alpha*lr}) + elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12 + return self.mid_lr_weight return 1 def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr): @@ -525,13 +515,15 @@ class LoRANetwork(torch.nn.Module): if self.unet_loras: for lora in self.unet_loras: - param_data={} + param_data = {'params': lora.parameters()} if unet_lr is not None: - param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr} + param_data['lr'] = unet_lr elif default_lr is not None: - param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr} - if param_data["lr"]==0: - continue + param_data['lr'] = default_lr + if self.stratified_lr and ('lr' in param_data): + param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora) + if (param_data['lr']==0): + continue all_params.append(param_data) return all_params From 19340d82e6fb2a081cadb5fc4c6f38aa627ea81d Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 2 Apr 2023 12:57:55 +0900 Subject: [PATCH 12/24] =?UTF-8?q?=E5=B1=A4=E5=88=A5=E5=AD=A6=E7=BF=92?= =?UTF-8?q?=E7=8E=87=E3=82=92=E4=BD=BF=E3=82=8F=E3=81=AA=E3=81=84=E5=A0=B4?= =?UTF-8?q?=E5=90=88=E3=81=ABparams=E3=82=92=E3=81=BE=E3=81=A8=E3=82=81?= =?UTF-8?q?=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index cfc517ce..6e860a03 100644 --- a/networks/lora.py +++ 
b/networks/lora.py @@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module): all_params.append(param_data) if self.unet_loras: - for lora in self.unet_loras: - param_data = {'params': lora.parameters()} + if self.stratified_lr: + for lora in self.unet_loras: + param_data = {'params': lora.parameters()} + if unet_lr is not None: + param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora) + elif default_lr is not None: + param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora) + if ('lr' in param_data) and (param_data['lr']==0): + continue + all_params.append(param_data) + else: + params = [] + for lora in self.unet_loras: + params.extend(lora.parameters()) + param_data = {'params': params} if unet_lr is not None: param_data['lr'] = unet_lr - elif default_lr is not None: - param_data['lr'] = default_lr - if self.stratified_lr and ('lr' in param_data): - param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora) - if (param_data['lr']==0): - continue all_params.append(param_data) + return all_params def enable_gradient_checkpointing(self): From 97e65bf93fb609da4df280e83839087f4743b744 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Apr 2023 16:10:09 +0900 Subject: [PATCH 13/24] change 'stratify' to 'block', add en message --- networks/lora.py | 189 ++++++++++++++++++++++++++++------------------- 1 file changed, 115 insertions(+), 74 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 6e860a03..f1a65074 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,7 +12,8 @@ import re from library import train_util -RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_') +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + class LoRAModule(torch.nn.Module): """ @@ -191,18 +192,22 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un conv_alpha=conv_alpha, ) - up_lr_weight=None - if 
'up_lr_weight' in kwargs: - up_lr_weight = kwargs.get('up_lr_weight',None) + # if some parameters are not set, use zero + up_lr_weight = kwargs.get("up_lr_weight", None) + if up_lr_weight is not None: if "," in up_lr_weight: - up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s] - down_lr_weight=None - if 'down_lr_weight' in kwargs: - down_lr_weight = kwargs.get('down_lr_weight',None) + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight = kwargs.get("down_lr_weight", None) + if down_lr_weight is not None: if "," in down_lr_weight: - down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s] - mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None - network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0))) + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + mid_lr_weight = kwargs.get("mid_lr_weight", None) + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))) return network @@ -328,17 +333,17 @@ class LoRANetwork(torch.nn.Module): self.weights_sd = None + self.up_lr_weight: list[float] = None + self.down_lr_weight: list[float] = None + self.mid_lr_weight: float = None + self.block_lr = False + # assertion names = set() for lora in self.text_encoder_loras + self.unet_loras: assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) - self.up_lr_weight:list[float] = None - self.down_lr_weight:list[float] = None - self.mid_lr_weight:float = None - self.stratified_lr = False - def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -389,13 +394,16 @@ class LoRANetwork(torch.nn.Module): skipped = [] for lora in 
self.text_encoder_loras + self.unet_loras: - if self.get_stratified_lr_weight(lora) == 0: + if self.block_lr and self.get_block_lr_weight(lora) == 0: skipped.append(lora.lora_name) continue lora.apply_to() self.add_module(lora.lora_name, lora) - if len(skipped)>0: - print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:") + + if len(skipped) > 0: + print( + f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) for name in skipped: print(f"\t{name}") @@ -431,76 +439,109 @@ class LoRANetwork(torch.nn.Module): if key.startswith(lora.lora_name): sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) + print(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する - def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0): - max_len=12 # フルモデル相当でのup,downの層の数 + def set_block_lr_weight( + self, + up_lr_weight: list[float] | str = None, + mid_lr_weight: float = None, + down_lr_weight: list[float] | str = None, + zero_threshold: float = 0.0, + ): + # バラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return + + max_len = 12 # フルモデル相当でのup,downの層の数 + def get_list(name) -> list[float]: - import math - if name=="cosine": - return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))] - elif name=="sine": - return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)] - elif name=="linear": - return [i/(max_len-1) for i in range(max_len)] - elif name=="reverse_linear": - return [i/(max_len-1) for i in reversed(range(max_len))] - elif name=="zeros": + import math + + if name == "cosine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) for 
i in range(max_len)] + elif name == "linear": + return [i / (max_len - 1) for i in range(max_len)] + elif name == "reverse_linear": + return [i / (max_len - 1) for i in reversed(range(max_len))] + elif name == "zeros": return [0.0] * max_len else: - print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name)) + print( + "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) return None - if type(down_lr_weight)==str: - down_lr_weight=get_list(down_lr_weight) - if type(up_lr_weight)==str: - up_lr_weight=get_list(up_lr_weight) + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) - if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len): - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len) - if (up_lr_weight != None and len(up_lr_weight) max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): + print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + up_lr_weight = up_lr_weight[:max_len] + down_lr_weight = down_lr_weight[:max_len] + + if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): + print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." 
% max_len) + print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + + if down_lr_weight != None and len(down_lr_weight) < max_len: down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) zero_threshold else 0 for w in down_lr_weight[:max_len]] - print("down_lr_weight(浅い層->深い層):",self.down_lr_weight) - if (mid_lr_weight != None): + print("apply block learning rate / 階層別学習率を適用します。") + self.block_lr = True + + if down_lr_weight != None: + self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight) + else: + print("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:",self.mid_lr_weight) - if (up_lr_weight != None): - self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]] - print("up_lr_weight(深い層->浅い層):",self.up_lr_weight) + print("mid_lr_weight:", self.mid_lr_weight) + else: + print("mid_lr_weight: 1.0") + + if up_lr_weight != None: + self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", self.up_lr_weight) + else: + print("up_lr_weight: all 1.0, すべて1.0") + return - def get_stratified_lr_weight(self, lora:LoRAModule) -> float: + def get_block_lr_weight(self, lora: LoRAModule) -> float: m = RE_UPDOWN.search(lora.lora_name) if m: g = m.groups() i = int(g[1]) j = int(g[3]) - if g[2]=="resnets": - idx=3*i + j - elif g[2]=="attentions": - idx=3*i + j - elif g[2]=="upsamplers" or g[2]=="downsamplers": - idx=3*i + 2 + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 - if (g[0]=="down") and (self.down_lr_weight != None): - return self.down_lr_weight[idx+1] - elif 
(g[0]=="up") and (self.up_lr_weight != None): + if (g[0] == "down") and (self.down_lr_weight != None): + return self.down_lr_weight[idx + 1] + if (g[0] == "up") and (self.up_lr_weight != None): return self.up_lr_weight[idx] - elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12 - return self.mid_lr_weight + elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12 + return self.mid_lr_weight return 1 - def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] @@ -508,29 +549,29 @@ class LoRANetwork(torch.nn.Module): params = [] for lora in self.text_encoder_loras: params.extend(lora.parameters()) - param_data = {'params': params} + param_data = {"params": params} if text_encoder_lr is not None: - param_data['lr'] = text_encoder_lr + param_data["lr"] = text_encoder_lr all_params.append(param_data) if self.unet_loras: - if self.stratified_lr: + if self.block_lr: for lora in self.unet_loras: - param_data = {'params': lora.parameters()} + param_data = {"params": lora.parameters()} if unet_lr is not None: - param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora) + param_data["lr"] = unet_lr * self.get_block_lr_weight(lora) elif default_lr is not None: - param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora) - if ('lr' in param_data) and (param_data['lr']==0): + param_data["lr"] = default_lr * self.get_block_lr_weight(lora) + if ("lr" in param_data) and (param_data["lr"] == 0): continue all_params.append(param_data) else: params = [] for lora in self.unet_loras: params.extend(lora.parameters()) - param_data = {'params': params} + param_data = {"params": params} if unet_lr is not None: - param_data['lr'] = unet_lr + param_data["lr"] = unet_lr all_params.append(param_data) return all_params From c639cb7d5dce38d5ceb2356da4cddab6c32a8263 Mon Sep 17 00:00:00 2001 From: 
Kohya S Date: Sun, 2 Apr 2023 16:18:04 +0900 Subject: [PATCH 14/24] support older type hint --- networks/lora.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index f1a65074..17bd0b38 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,7 +5,7 @@ import math import os -from typing import List +from typing import List, Union import numpy as np import torch import re @@ -247,7 +247,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class LoRANetwork(torch.nn.Module): - # is it possible to apply conv_in and conv_out? + # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;) UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] @@ -333,8 +333,8 @@ class LoRANetwork(torch.nn.Module): self.weights_sd = None - self.up_lr_weight: list[float] = None - self.down_lr_weight: list[float] = None + self.up_lr_weight: List[float] = None + self.down_lr_weight: List[float] = None self.mid_lr_weight: float = None self.block_lr = False @@ -445,9 +445,9 @@ class LoRANetwork(torch.nn.Module): # 層別学習率用に層ごとの学習率に対する倍率を定義する def set_block_lr_weight( self, - up_lr_weight: list[float] | str = None, + up_lr_weight: Union[List[float], str] = None, mid_lr_weight: float = None, - down_lr_weight: list[float] | str = None, + down_lr_weight: Union[List[float], str] = None, zero_threshold: float = 0.0, ): # バラメータ未指定時は何もせず、今までと同じ動作とする @@ -456,7 +456,7 @@ class LoRANetwork(torch.nn.Module): max_len = 12 # フルモデル相当でのup,downの層の数 - def get_list(name) -> list[float]: + def get_list(name) -> List[float]: import math if name == "cosine": From 3beddf341e6f14bf402e51200cb12377b4d9910d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 08:43:11 +0900 Subject: [PATCH 15/24] Support LR graphs for each 
block, base lr --- networks/lora.py | 95 ++++++++++++++++++++++++++++++++++-------------- train_network.py | 46 +++++++++++++++-------- 2 files changed, 98 insertions(+), 43 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 17bd0b38..27335efe 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,7 +5,7 @@ import math import os -from typing import List, Union +from typing import List, Tuple, Union import numpy as np import torch import re @@ -247,6 +247,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class LoRANetwork(torch.nn.Module): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;) UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] @@ -394,7 +396,7 @@ class LoRANetwork(torch.nn.Module): skipped = [] for lora in self.text_encoder_loras + self.unet_loras: - if self.block_lr and self.get_block_lr_weight(lora) == 0: + if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight skipped.append(lora.lora_name) continue lora.apply_to() @@ -450,25 +452,29 @@ class LoRANetwork(torch.nn.Module): down_lr_weight: Union[List[float], str] = None, zero_threshold: float = 0.0, ): - # バラメータ未指定時は何もせず、今までと同じ動作とする + # パラメータ未指定時は何もせず、今までと同じ動作とする if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: return - max_len = 12 # フルモデル相当でのup,downの層の数 + max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 - def get_list(name) -> List[float]: + def get_list(name_with_suffix) -> List[float]: import math + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))] + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in 
reversed(range(max_len))] elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)] + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] elif name == "linear": - return [i / (max_len - 1) for i in range(max_len)] + return [i / (max_len - 1) + base_lr for i in range(max_len)] elif name == "reverse_linear": - return [i / (max_len - 1) for i in reversed(range(max_len))] + return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] elif name == "zeros": - return [0.0] * max_len + return [0.0 + base_lr] * max_len else: print( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" @@ -520,7 +526,9 @@ class LoRANetwork(torch.nn.Module): return - def get_block_lr_weight(self, lora: LoRAModule) -> float: + def get_block_index(self, lora: LoRAModule) -> int: + block_idx = -1 # invalid lora name + m = RE_UPDOWN.search(lora.lora_name) if m: g = m.groups() @@ -533,43 +541,74 @@ class LoRANetwork(torch.nn.Module): elif g[2] == "upsamplers" or g[2] == "downsamplers": idx = 3 * i + 2 - if (g[0] == "down") and (self.down_lr_weight != None): - return self.down_lr_weight[idx + 1] - if (g[0] == "up") and (self.up_lr_weight != None): - return self.up_lr_weight[idx] - elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12 - return self.mid_lr_weight - return 1 + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + + elif "mid_block_" in lora.lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + + return block_idx + + def get_lr_weight(self, lora: LoRAModule) -> float: + lr_weight = 1.0 + block_idx = self.get_block_index(lora) + if block_idx < 0: + return lr_weight + + if block_idx < LoRANetwork.NUM_OF_BLOCKS: + if self.down_lr_weight != None: + lr_weight = self.down_lr_weight[block_idx] + elif block_idx 
== LoRANetwork.NUM_OF_BLOCKS: + if self.mid_lr_weight != None: + lr_weight = self.mid_lr_weight + elif block_idx > LoRANetwork.NUM_OF_BLOCKS: + if self.up_lr_weight != None: + lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] + + return lr_weight def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] - if self.text_encoder_loras: + def enumerate_params(loras): params = [] - for lora in self.text_encoder_loras: + for lora in loras: params.extend(lora.parameters()) - param_data = {"params": params} + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} if text_encoder_lr is not None: param_data["lr"] = text_encoder_lr all_params.append(param_data) if self.unet_loras: if self.block_lr: + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + block_idx_to_lora = {} for lora in self.unet_loras: - param_data = {"params": lora.parameters()} + idx = self.get_block_index(lora) + if idx not in block_idx_to_lora: + block_idx_to_lora[idx] = [] + block_idx_to_lora[idx].append(lora) + + # blockごとにパラメータを設定する + for idx, block_loras in block_idx_to_lora.items(): + param_data = {"params": enumerate_params(block_loras)} + if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_block_lr_weight(lora) + param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_block_lr_weight(lora) + param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) if ("lr" in param_data) and (param_data["lr"] == 0): continue all_params.append(param_data) + else: - params = [] - for lora in self.unet_loras: - params.extend(lora.parameters()) - param_data = {"params": params} + param_data = {"params": enumerate_params(self.unet_loras)} if unet_lr is not None: param_data["lr"] = unet_lr all_params.append(param_data) diff --git a/train_network.py b/train_network.py index 2b824018..a7b167bf 
100644 --- a/train_network.py +++ b/train_network.py @@ -32,16 +32,31 @@ from library.custom_train_functions import apply_snr_weight def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = {"loss/current": current_loss, "loss/average": avr_loss} - if args.network_train_unet_only: - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - else: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder + lrs = lr_scheduler.get_last_lr() - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) + if args.network_train_unet_only: + logs["lr/unet"] = float(lrs[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lrs[0]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/block{i}"] = float(lrs[i]) + if args.optimizer_type.lower() == "DAdaptation".lower(): + logs[f"lr/d*lr/block{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) return logs @@ -99,10 +114,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) + current_epoch = Value("i", 0) + current_step = Value("i", 0) ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) @@ -146,7 +161,6 @@ def train(args): torch.cuda.empty_cache() accelerator.wait_for_everyone() - # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -214,7 +228,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) if is_main_process: print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") @@ -518,7 +534,7 @@ def train(args): for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) From 817a9268ff85ce5a01b7f397c340153f3a5b0d24 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 08:43:26 +0900 Subject: [PATCH 16/24] update readme for block weight lr --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index 76695270..7ef0947a 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +- 3 Apr. 2023, 2023/4/3: + - Add `--network_args` option to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for your great contribution! + - Specify the weights of 25 blocks for the full model. + - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. + - Specify the following arguments with `--network_args`. + - `down_lr_weight` : Specify the learning rate weight of the down blocks of U-Net. The following can be specified. + - The weight for each block: Specify 12 numbers such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`. + - Specify from preset: Specify such as `"down_lr_weight=sine"` (the weights by sine curve). sine, cosine, linear, reverse_linear, zeros can be specified. Also, if you add `+number` such as `"down_lr_weight=cosine+.25"`, the specified number is added (such as 0.25~1.25). + - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`. 
+ - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight. + - If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created. + + - 階層別学習率を `train_network.py` で指定できるようにしました。u-haru 氏の多大な貢献に感謝します。 + - フルモデルの25個のブロックの重みを指定できます。 + - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合は一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 + -`--network_args` で以下の引数を指定してください。 + - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 + - `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 + - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 + - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 + + - 1 Apr. 2023, 2023/4/1: - Fix an issue that `merge_lora.py` does not work with the latest version. - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights. From 6134619998bb0f7d319e466f8632c7776440f787 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 21:19:49 +0900 Subject: [PATCH 17/24] Add block dim(rank) feature --- README.md | 64 ++-- gen_img_diffusers.py | 6 +- networks/extract_lora_from_models.py | 4 +- networks/lora.py | 543 ++++++++++++++++----------- 4 files changed, 361 insertions(+), 256 deletions(-) diff --git a/README.md b/README.md index 7ef0947a..0590bfd0 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History -- 3 Apr. 2023, 2023/4/3: - - Add `--network_args` option to `train_network.py` to specify block weights for learning rates. 
Thanks to u-haru for your great contribution! +- 4 Apr. 2023, 2023/4/4: + - Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution! - Specify the weights of 25 blocks for the full model. - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. - Specify the following arguments with `--network_args`. @@ -138,10 +138,19 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`. - `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight. - If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created. + - `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0. - - 階層別学習率を `train_network.py` で指定できるようにしました。u-haru 氏の多大な貢献に感謝します。 + - Add options to `train_network.py` to specify block dims (ranks) for variable rank. + - Specify 25 values ​​for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values ​​always. + - Specify the following arguments with `--network_args`. + - `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`. + - `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used. + - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. + - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. 
If omitted, the value of conv_alpha is used. + + - 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。 - フルモデルの25個のブロックの重みを指定できます。 - - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合は一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 + - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 -`--network_args` で以下の引数を指定してください。 - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 @@ -149,33 +158,30 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - + - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 -- 1 Apr. 2023, 2023/4/1: - - Fix an issue that `merge_lora.py` does not work with the latest version. - - Fix an issue that `merge_lora.py` does not merge Conv2d3x3 weights. - - 最新のバージョンで`merge_lora.py` が動作しない不具合を修正しました。 - - `merge_lora.py` で `no module found for LoRA weight: ...` と表示され Conv2d3x3 拡張の重みがマージされない不具合を修正しました。 -- 31 Mar. 2023, 2023/3/31: - - Fix an issue that the VRAM usage temporarily increases when loading a model in `train_network.py`. - - Fix an issue that an error occurs when loading a `.safetensors` model in `train_network.py`. [#354](https://github.com/kohya-ss/sd-scripts/issues/354) - - `train_network.py` でモデル読み込み時にVRAM使用量が一時的に大きくなる不具合を修正しました。 - - `train_network.py` で `.safetensors` 形式のモデルを読み込むとエラーになる不具合を修正しました。[#354](https://github.com/kohya-ss/sd-scripts/issues/354) -- 30 Mar. 2023, 2023/3/30: - - Support [P+](https://prompt-plus.github.io/) training. Thank you jakaline-dev! - - See [#327](https://github.com/kohya-ss/sd-scripts/pull/327) for details. 
- - Use `train_textual_inversion_XTI.py` for training. The usage is almost the same as `train_textual_inversion.py`. However, sample image generation during training is not supported. - - Use `gen_img_diffusers.py` for image generation (I think Web UI is not supported). Specify the embedding with `--XTI_embeddings` option. - - Reduce RAM usage at startup in `train_network.py`. [#332](https://github.com/kohya-ss/sd-scripts/pull/332) Thank you guaneec! - - Support pre-merge for LoRA in `gen_img_diffusers.py`. Specify `--network_merge` option. Note that the `--am` option of the prompt option is no longer available with this option. + - 階層別dim (rank)を `train_network.py` で指定できるようになりました。 + - フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 + - `--network_args` で以下の引数を指定してください。 + - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 + - `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。 + - `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。 + - `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。 + + - 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification: + + ` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"` + + ` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"` + + - 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification: + + ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"` + + ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` + + ` --network_args 
"block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` - - [P+](https://prompt-plus.github.io/) の学習に対応しました。jakaline-dev氏に感謝します。 - - 詳細は [#327](https://github.com/kohya-ss/sd-scripts/pull/327) をご参照ください。 - - 学習には `train_textual_inversion_XTI.py` を使用します。使用法は `train_textual_inversion.py` とほぼ同じです。た - だし学習中のサンプル生成には対応していません。 - - 画像生成には `gen_img_diffusers.py` を使用してください(Web UIは対応していないと思われます)。`--XTI_embeddings` オプションで学習したembeddingを指定してください。 - - `train_network.py` で起動時のRAM使用量を削減しました。[#332](https://github.com/kohya-ss/sd-scripts/pull/332) guaneec氏に感謝します。 - - `gen_img_diffusers.py` でLoRAの事前マージに対応しました。`--network_merge` オプションを指定してください。なおプロンプトオプションの `--am` は使用できなくなります。 ## Sample image generation during training A prompt file might look like this, for example diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 225de33c..a0469766 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2275,7 +2275,7 @@ def main(args): if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - network = imported_module.create_network_from_weights( + network, weights_sd = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, **net_kwargs ) else: @@ -2285,6 +2285,8 @@ def main(args): if not args.network_merge: network.apply_to(text_encoder, unet) + info = network.load_state_dict(weights_sd, False) + print(f"weights are loaded: {info}") if args.opt_channels_last: network.to(memory_format=torch.channels_last) @@ -2292,7 +2294,7 @@ def main(args): networks.append(network) else: - network.merge_to(text_encoder, unet, dtype, device) + network.merge_to(text_encoder, unet, weights_sd, dtype, device) else: networks = [] diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 9aa28485..f001e7eb 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -145,8 +145,8 @@ 
def svd(args): lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) - lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") diff --git a/networks/lora.py b/networks/lora.py index 27335efe..c5372688 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -143,6 +143,8 @@ class LoRAModule(torch.nn.Module): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 # extract dim/alpha for conv2d, and block dim conv_dim = kwargs.get("conv_dim", None) @@ -154,34 +156,50 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un else: conv_alpha = float(conv_alpha) - """ - block_dims = kwargs.get("block_dims") - block_alphas = None + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + down_lr_weight = kwargs.get("down_lr_weight", None) + mid_lr_weight = kwargs.get("mid_lr_weight", None) + up_lr_weight = kwargs.get("up_lr_weight", None) + + # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする + if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + block_dims, 
block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + # remove block dim/alpha without learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + ) - if block_dims is not None: - block_dims = [int(d) for d in block_dims.split(',')] - assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" - block_alphas = kwargs.get("block_alphas") - if block_alphas is None: - block_alphas = [1] * len(block_dims) else: - block_alphas = [int(a) for a in block_alphas(',')] - assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" - - conv_block_dims = kwargs.get("conv_block_dims") - conv_block_alphas = None - - if conv_block_dims is not None: - conv_block_dims = [int(d) for d in conv_block_dims.split(',')] - assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}" - conv_block_alphas = kwargs.get("conv_block_alphas") - if conv_block_alphas is None: - conv_block_alphas = [1] * len(conv_block_dims) - else: - conv_block_alphas = [int(a) for a in conv_block_alphas(',')] - assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same 
to {NUM_BLOCKS}" - """ + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoder, unet, @@ -190,28 +208,219 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, ) - # if some parameters are not set, use zero - up_lr_weight = kwargs.get("up_lr_weight", None) - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight = kwargs.get("down_lr_weight", None) - if down_lr_weight is not None: - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - mid_lr_weight = kwargs.get("mid_lr_weight", None) - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) return network +# このメソッドは外部から呼び出される可能性を考慮しておく +# network_dim, network_alpha にはデフォルト値が入っている。 +# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている +# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている +def get_block_dims_and_alphas( + block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + # block_dimsとblock_alphasをパースする。必ず値が入る + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert ( + len(block_dims) == 
num_total_blocks + ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + else: + print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + print( + f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + print( + f"conv_block_alphas is not specified. 
all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + print( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +def get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold +) -> Tuple[List[float], List[float], List[float]]: + # パラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None, None, None + + max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + elif name == "linear": + return [i / (max_len - 1) + base_lr for i in range(max_len)] + elif name == "reverse_linear": + return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + elif name == "zeros": + return [0.0 + base_lr] * max_len + else: + print( + "Unknown lr_weight argument %s is used. 
Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): + print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) + print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) + up_lr_weight = up_lr_weight[:max_len] + down_lr_weight = down_lr_weight[:max_len] + + if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): + print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) + print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + + if down_lr_weight != None and len(down_lr_weight) < max_len: + down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len: + up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + print("apply block learning rate / 階層別学習率を適用します。") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) + else: + print("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: + mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + print("mid_lr_weight:", mid_lr_weight) + else: + print("mid_lr_weight: 1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) + else: + print("up_lr_weight: all 1.0, 
すべて1.0") + + return down_lr_weight, mid_lr_weight, up_lr_weight + + +# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく +def remove_block_dims_and_alphas( + block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight +): + # set 0 to block dim without learning rate to remove the block + if down_lr_weight != None: + for i, lr in enumerate(down_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + if mid_lr_weight != None: + if mid_lr_weight == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 + if up_lr_weight != None: + for i, lr in enumerate(up_lr_weight): + if lr == 0: + block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + if conv_block_dims is not None: + conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str) -> int: + block_idx = -1 # invalid lora name + + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + + return block_idx + + +# Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": @@ -242,8 +451,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh modules_alpha = modules_dim[key] network 
= LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) - network.weights_sd = weights_sd - return network + return network, weights_sd class LoRANetwork(torch.nn.Module): @@ -265,9 +473,22 @@ class LoRANetwork(torch.nn.Module): alpha=1, conv_lora_dim=None, conv_alpha=None, + block_dims=None, + block_alphas=None, + conv_block_dims=None, + conv_block_alphas=None, modules_dim=None, modules_alpha=None, + varbose=False, ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ super().__init__() self.multiplier = multiplier @@ -278,62 +499,83 @@ class LoRANetwork(torch.nn.Module): if modules_dim is not None: print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") else: print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - - self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None - if self.apply_to_conv2d_3x3: - if self.conv_alpha is None: - self.conv_alpha = self.alpha - print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + if self.conv_lora_dim is not None: + print(f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances - def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: + prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER loras = [] + skipped = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: - # TODO get block index here for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + if is_linear or is_conv2d: lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") + dim = None + alpha = None if modules_dim is not None: - if lora_name not in modules_dim: - continue # no LoRA module in this weights file - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + block_idx = get_block_index(lora_name) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] else: if is_linear or is_conv2d_1x1: dim = self.lora_dim alpha = self.alpha - elif self.apply_to_conv2d_3x3: + elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha - else: - continue + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha) loras.append(lora) - return loras + return 
loras, skipped - self.text_encoder_loras = create_modules( - LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - ) + self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE - if modules_dim is not None or self.conv_lora_dim is not None: + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) + self.unet_loras, skipped_un = create_modules(True, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - self.weights_sd = None + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + print( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + print(f"\t{name}") self.up_lr_weight: List[float] = None self.down_lr_weight: List[float] = None @@ -351,39 +593,7 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier - def load_weights(self, file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file, safe_open - - self.weights_sd = load_file(file) - else: - self.weights_sd = torch.load(file, map_location="cpu") - - def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): - if self.weights_sd: - weights_has_text_encoder = weights_has_unet = False - for key in self.weights_sd.keys(): - if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): - 
weights_has_text_encoder = True - elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): - weights_has_unet = True - - if apply_text_encoder is None: - apply_text_encoder = weights_has_text_encoder - else: - assert ( - apply_text_encoder == weights_has_text_encoder - ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" - - if apply_unet is None: - apply_unet = weights_has_unet - else: - assert ( - apply_unet == weights_has_unet - ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" - else: - assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" - + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: print("enable LoRA for text encoder") else: @@ -394,32 +604,14 @@ class LoRANetwork(torch.nn.Module): else: self.unet_loras = [] - skipped = [] for lora in self.text_encoder_loras + self.unet_loras: - if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight - skipped.append(lora.lora_name) - continue lora.apply_to() self.add_module(lora.lora_name, lora) - if len(skipped) > 0: - print( - f"because block_lr_weight is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" - ) - for name in skipped: - print(f"\t{name}") - - if self.weights_sd: - # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) - info = self.load_state_dict(self.weights_sd, False) - print(f"weights are loaded: {info}") - # TODO refactor to common function with apply_to - def merge_to(self, text_encoder, unet, dtype, device): - assert self.weights_sd is not None, "weights are not loaded" - + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): apply_text_encoder = apply_unet = False - for key in self.weights_sd.keys(): + for key in weights_sd.keys(): if 
key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): apply_text_encoder = True elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): @@ -437,9 +629,9 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: sd_for_lora = {} - for key in self.weights_sd.keys(): + for key in weights_sd.keys(): if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1 :]] = self.weights_sd[key] + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) print(f"weights are merged") @@ -447,113 +639,18 @@ class LoRANetwork(torch.nn.Module): # 層別学習率用に層ごとの学習率に対する倍率を定義する def set_block_lr_weight( self, - up_lr_weight: Union[List[float], str] = None, + up_lr_weight: List[float] = None, mid_lr_weight: float = None, - down_lr_weight: Union[List[float], str] = None, - zero_threshold: float = 0.0, + down_lr_weight: List[float] = None, ): - # パラメータ未指定時は何もせず、今までと同じ動作とする - if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return - - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 - - def get_list(name_with_suffix) -> List[float]: - import math - - tokens = name_with_suffix.split("+") - name = tokens[0] - base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 - - if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] - elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] - elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] - elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] - elif name == "zeros": - return [0.0 + base_lr] * max_len - else: - print( - "Unknown lr_weight argument %s is used. 
Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" - % (name) - ) - return None - - if type(down_lr_weight) == str: - down_lr_weight = get_list(down_lr_weight) - if type(up_lr_weight) == str: - up_lr_weight = get_list(up_lr_weight) - - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] - - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) - - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) - - if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): - print("apply block learning rate / 階層別学習率を適用します。") - self.block_lr = True - - if down_lr_weight != None: - self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] - print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", self.down_lr_weight) - else: - print("down_lr_weight: all 1.0, すべて1.0") - - if mid_lr_weight != None: - self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 - print("mid_lr_weight:", self.mid_lr_weight) - else: - print("mid_lr_weight: 1.0") - - if up_lr_weight != None: - self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] - print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", 
self.up_lr_weight) - else: - print("up_lr_weight: all 1.0, すべて1.0") - - return - - def get_block_index(self, lora: LoRAModule) -> int: - block_idx = -1 # invalid lora name - - m = RE_UPDOWN.search(lora.lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora.lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 - - return block_idx + self.block_lr = True + self.down_lr_weight = down_lr_weight + self.mid_lr_weight = mid_lr_weight + self.up_lr_weight = up_lr_weight def get_lr_weight(self, lora: LoRAModule) -> float: lr_weight = 1.0 - block_idx = self.get_block_index(lora) + block_idx = get_block_index(lora.lora_name) if block_idx < 0: return lr_weight @@ -590,7 +687,7 @@ class LoRANetwork(torch.nn.Module): # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 block_idx_to_lora = {} for lora in self.unet_loras: - idx = self.get_block_index(lora) + idx = get_block_index(lora.lora_name) if idx not in block_idx_to_lora: block_idx_to_lora[idx] = [] block_idx_to_lora[idx].append(lora) From 53cc3583df729dc69349b687cb52caed786ff3b2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 21:46:12 +0900 Subject: [PATCH 18/24] fix potential issue with dtype --- fine_tune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index 637a729a..50549878 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -275,7 +275,7 @@ def train(args): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) + latents = 
batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() From 7209eb74ccaa21825001bdca9c3a57ce7aae19d1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 22:08:58 +0900 Subject: [PATCH 19/24] update readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 0590bfd0..52fecd89 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History - 4 Apr. 2023, 2023/4/4: + - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. + + - Fix some bugs and add some features. + - Fix an issue that `.json` format dataset config files cannot be read. [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) Thanks to rockerBOO! + - Raise an error when an invalid `--lr_warmup_steps` option is specified (when warmup is not valid for the specified scheduler). [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) Thanks to shirayu! + - Fix the data type handling in `fine_tune.py`. This may fix an error that occurs in some environments when using xformers, npz format cache, and mixed_precision. + - Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution! - Specify the weights of 25 blocks for the full model. - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. 
@@ -148,6 +155,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used. + - 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。 + + - いくつかのバグ修正、機能追加を行いました。 + - `.json`形式のdataset設定ファイルを読み込めない不具合を修正しました。 [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) rockerBOO 氏に感謝します。 + - 無効な`--lr_warmup_steps` オプション(指定したスケジューラでwarmupが無効な場合)を指定している場合にエラーを出すようにしました。 [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) shirayu 氏に感謝します。 + - `fine_tune.py` でデータ型の取り扱いが誤っていたのを修正しました。一部の環境でxformersを使い、npz形式のキャッシュ、mixed_precisionで学習した時にエラーとなる不具合が解消されるかもしれません。 + - 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。 - フルモデルの25個のブロックの重みを指定できます。 - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 From 83c7e03d050fc25f47a591c4ddfe28abdabc7ae7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 3 Apr 2023 22:45:28 +0900 Subject: [PATCH 20/24] Fix network_weights not working in train_network --- gen_img_diffusers.py | 2 +- networks/lora.py | 11 +++++++++++ train_network.py | 8 ++++---- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a0469766..af83ce47 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2285,7 +2285,7 @@ def main(args): if not args.network_merge: network.apply_to(text_encoder, unet) - info = network.load_state_dict(weights_sd, False) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") if args.opt_channels_last: diff --git a/networks/lora.py b/networks/lora.py index c5372688..4e0573d0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -593,6 +593,17 @@ class 
LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: print("enable LoRA for text encoder") diff --git a/train_network.py b/train_network.py index a7b167bf..c79b0922 100644 --- a/train_network.py +++ b/train_network.py @@ -194,14 +194,14 @@ def train(args): if network is None: return - if args.network_weights is not None: - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) - train_unet = not args.network_train_text_encoder_only train_text_encoder = not args.network_train_unet_only network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + print(f"load network weights from {args.network_weights}: {info}") + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() From 626d4b433a64b030cfe67df547467939b2b572c3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 3 Apr 2023 12:38:20 -0400 Subject: [PATCH 21/24] Add min_snr_gamma to metadata --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 476f76df..da91ec4f 100644 --- a/train_network.py +++ b/train_network.py @@ -346,6 +346,7 @@ def train(args): "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, "ss_face_crop_aug_range": args.face_crop_aug_range, "ss_prior_loss_weight": args.prior_loss_weight, + "ss_min_snr_gamma": args.min_snr_gamma, } if use_user_config: From 
e4eb3e63e67038897840ebb8c2c4d781f8cfde60 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Apr 2023 07:48:48 +0900 Subject: [PATCH 22/24] improve compatibility --- train_network.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index c79b0922..07bf44bb 100644 --- a/train_network.py +++ b/train_network.py @@ -52,9 +52,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche idx = 1 for i in range(idx, len(lrs)): - logs[f"lr/block{i}"] = float(lrs[i]) + logs[f"lr/group{i}"] = float(lrs[i]) if args.optimizer_type.lower() == "DAdaptation".lower(): - logs[f"lr/d*lr/block{i}"] = ( + logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) @@ -193,6 +193,9 @@ def train(args): network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return + + if hasattr(network, "prepare_network"): + network.prepare_network(args) train_unet = not args.network_train_text_encoder_only train_text_encoder = not args.network_train_unet_only @@ -490,8 +493,6 @@ def train(args): # add extra args if args.network_args: metadata["ss_network_args"] = json.dumps(net_kwargs) - # for key, value in net_kwargs.items(): - # metadata["ss_arg_" + key] = value # model name and hash if args.pretrained_model_name_or_path is not None: From 76bac2c1c59a58d94db4cc803a074560119cc307 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Apr 2023 08:27:11 +0900 Subject: [PATCH 23/24] add backward compatiblity --- train_network.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 0fd8cc0f..9956a905 100644 --- a/train_network.py +++ b/train_network.py @@ -213,7 +213,13 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - trainable_params = 
network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + # 後方互換性を確保するよ + try: + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + except TypeError: + print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)") + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する From 4b47e8ecb0b58b2a9dc3ca7050c2c0f2b2641092 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Apr 2023 08:27:30 +0900 Subject: [PATCH 24/24] update readme --- README.md | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 52fecd89..567ba6a6 100644 --- a/README.md +++ b/README.md @@ -129,13 +129,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - 4 Apr. 2023, 2023/4/4: - There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while. + - The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed. - Fix some bugs and add some features. - Fix an issue that `.json` format dataset config files cannot be read. [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) Thanks to rockerBOO! - Raise an error when an invalid `--lr_warmup_steps` option is specified (when warmup is not valid for the specified scheduler). [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) Thanks to shirayu! + - Add `min_snr_gamma` to metadata in `train_network.py`. [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) Thanks to rockerBOO! - Fix the data type handling in `fine_tune.py`. 
This may fix an error that occurs in some environments when using xformers, npz format cache, and mixed_precision. - - Add options to `train_network.py` to specify block weights for learning rates. Thanks to u-haru for the great contribution! + - Add options to `train_network.py` to specify block weights for learning rates. [PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) Thanks to u-haru for the great contribution! - Specify the weights of 25 blocks for the full model. - No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values ​​for the argument for consistency. - Specify the following arguments with `--network_args`. @@ -156,13 +158,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used. 
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。 - + - 階層別学習率、階層別dim(rank)についてはモジュール側の変更が必要なため、当リポジトリ内のnetworkモジュール以外(LyCORISなど)では現在は動作しないと思われます。 + - いくつかのバグ修正、機能追加を行いました。 - `.json`形式のdataset設定ファイルを読み込めない不具合を修正しました。 [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) rockerBOO 氏に感謝します。 - 無効な`--lr_warmup_steps` オプション(指定したスケジューラでwarmupが無効な場合)を指定している場合にエラーを出すようにしました。 [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) shirayu 氏に感謝します。 + - `train_network.py` で `min_snr_gamma` をメタデータに追加しました。 [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) rockerBOO 氏に感謝します。 - `fine_tune.py` でデータ型の取り扱いが誤っていたのを修正しました。一部の環境でxformersを使い、npz形式のキャッシュ、mixed_precisionで学習した時にエラーとなる不具合が解消されるかもしれません。 - - - 階層別学習率を `train_network.py` で指定できるようになりました。u-haru 氏の多大な貢献に感謝します。 + + - 階層別学習率を `train_network.py` で指定できるようになりました。[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) u-haru 氏の多大な貢献に感謝します。 - フルモデルの25個のブロックの重みを指定できます。 - 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 -`--network_args` で以下の引数を指定してください。 @@ -188,6 +192,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"` + - 階層別学習率tomlファイル指定例 / Examples of block learning rate toml file specification + + `network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]` + + `network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]` + + - 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification: ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"` @@ -196,6 +207,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from 
Diffuser ` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` + - 階層別dim (rank)tomlファイル指定例 / Examples of block dim (rank) toml file specification + + `network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]` + + `network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]` + ## Sample image generation during training A prompt file might look like this, for example