From 3beddf341e6f14bf402e51200cb12377b4d9910d Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 3 Apr 2023 08:43:11 +0900
Subject: [PATCH] Support LR graphs for each block, base lr

---
 networks/lora.py | 95 ++++++++++++++++++++++++++++++++++--------------
 train_network.py | 46 +++++++++++++++--------
 2 files changed, 98 insertions(+), 43 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index 17bd0b38..27335efe 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -5,7 +5,7 @@
 import math
 import os
-from typing import List, Union
+from typing import List, Tuple, Union
 import numpy as np
 import torch
 import re
@@ -247,6 +247,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):


 class LoRANetwork(torch.nn.Module):
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
     # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
@@ -394,7 +396,7 @@ class LoRANetwork(torch.nn.Module):
         skipped = []
         for lora in self.text_encoder_loras + self.unet_loras:
-            if self.block_lr and self.get_block_lr_weight(lora) == 0:
+            if self.block_lr and self.get_lr_weight(lora) == 0:  # no LR weight
                 skipped.append(lora.lora_name)
                 continue
             lora.apply_to()
@@ -450,25 +452,29 @@ class LoRANetwork(torch.nn.Module):
         down_lr_weight: Union[List[float], str] = None,
         zero_threshold: float = 0.0,
     ):
-        # バラメータ未指定時は何もせず、今までと同じ動作とする
+        # パラメータ未指定時は何もせず、今までと同じ動作とする (do nothing when not specified, keeping the previous behavior)
         if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
             return

-        max_len = 12  # number of up/down blocks in the full model
+        max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model

-        def get_list(name) -> List[float]:
+        def get_list(name_with_suffix) -> List[float]:
             import math

+            tokens = name_with_suffix.split("+")
+            name = tokens[0]
+            base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
+
             if name == "cosine":
-                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
             elif name == "sine":
-                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
             elif name == "linear":
-                return [i / (max_len - 1) for i in range(max_len)]
+                return [i / (max_len - 1) + base_lr for i in range(max_len)]
             elif name == "reverse_linear":
-                return [i / (max_len - 1) for i in reversed(range(max_len))]
+                return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
             elif name == "zeros":
-                return [0.0] * max_len
+                return [0.0 + base_lr] * max_len
             else:
                 print(
                     "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
                     % (name, name)
                 )
                 return None
@@ -520,7 +526,9 @@ class LoRANetwork(torch.nn.Module):

         return

-    def get_block_lr_weight(self, lora: LoRAModule) -> float:
+    def get_block_index(self, lora: LoRAModule) -> int:
+        block_idx = -1  # invalid lora name
+
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
             g = m.groups()
@@ -533,43 +541,74 @@ class LoRANetwork(torch.nn.Module):
             elif g[2] == "upsamplers" or g[2] == "downsamplers":
                 idx = 3 * i + 2

-            if (g[0] == "down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx + 1]
-            if (g[0] == "up") and (self.up_lr_weight != None):
-                return self.up_lr_weight[idx]
-        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
-            return self.mid_lr_weight
-        return 1
+            if g[0] == "down":
+                block_idx = 1 + idx  # no LoRA module corresponds to index 0
+            elif g[0] == "up":
+                block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+
+        elif "mid_block_" in lora.lora_name:
+            block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+
+        return block_idx
+
+    def get_lr_weight(self, lora: LoRAModule) -> float:
+        lr_weight = 1.0
+        block_idx = self.get_block_index(lora)
+        if block_idx < 0:
+            return lr_weight
+
+        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
+            if self.down_lr_weight != None:
+                lr_weight = self.down_lr_weight[block_idx]
+        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
+            if self.mid_lr_weight != None:
+                lr_weight = self.mid_lr_weight
+        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
+            if self.up_lr_weight != None:
+                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
+
+        return lr_weight

     def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

-        if self.text_encoder_loras:
+        def enumerate_params(loras):
             params = []
-            for lora in self.text_encoder_loras:
+            for lora in loras:
                 params.extend(lora.parameters())
-            param_data = {"params": params}
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
             if text_encoder_lr is not None:
                 param_data["lr"] = text_encoder_lr
             all_params.append(param_data)

         if self.unet_loras:
             if self.block_lr:
+                # we want the LR graph per block, so group the LoRAs by block
+                block_idx_to_lora = {}
                 for lora in self.unet_loras:
-                    param_data = {"params": lora.parameters()}
+                    idx = self.get_block_index(lora)
+                    if idx not in block_idx_to_lora:
+                        block_idx_to_lora[idx] = []
+                    block_idx_to_lora[idx].append(lora)
+
+                # set the parameters for each block
+                for idx, block_loras in block_idx_to_lora.items():
+                    param_data = {"params": enumerate_params(block_loras)}
+
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
                     if ("lr" in param_data) and (param_data["lr"] == 0):
                         continue
                     all_params.append(param_data)
+
             else:
-                params = []
-                for lora in self.unet_loras:
-                    params.extend(lora.parameters())
-                param_data = {"params": params}
+                param_data = {"params": enumerate_params(self.unet_loras)}
                 if unet_lr is not None:
                     param_data["lr"] = unet_lr
                 all_params.append(param_data)
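Note: the effect of the new "<name>+<base_lr>" syntax accepted by get_list can be sketched standalone. The profile helper below is hypothetical, written only to illustrate the math above (an unknown name raises KeyError here, where the patch prints an error and returns None); printed values are approximate.

    import math

    NUM_OF_BLOCKS = 12  # up/down block count, as in LoRANetwork above

    def profile(name_with_suffix):
        # "<name>+<base_lr>" -> 12 per-block LR multipliers, mirroring get_list()
        tokens = name_with_suffix.split("+")
        name = tokens[0]
        base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
        n = NUM_OF_BLOCKS
        curves = {
            "cosine": [math.sin(math.pi * (i / (n - 1)) / 2) for i in reversed(range(n))],
            "sine": [math.sin(math.pi * (i / (n - 1)) / 2) for i in range(n)],
            "linear": [i / (n - 1) for i in range(n)],
            "reverse_linear": [i / (n - 1) for i in reversed(range(n))],
            "zeros": [0.0] * n,
        }
        return [w + base_lr for w in curves[name]]

    print(profile("sine+0.1")[:3])  # ≈ [0.1, 0.242, 0.382]

The base lr acts as a floor added to every entry, so for example "zeros+.5" trains all twelve blocks at half the unet lr, while "sine" ramps the weights from 0 up to 1 across the twelve entries.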
diff --git a/train_network.py b/train_network.py
index 2b824018..a7b167bf 100644
--- a/train_network.py
+++ b/train_network.py
@@ -32,16 +32,31 @@ from library.custom_train_functions import apply_snr_weight

 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}

-    if args.network_train_unet_only:
-        logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
-    elif args.network_train_text_encoder_only:
-        logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
-    else:
-        logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
-        logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be the same as textencoder
+    lrs = lr_scheduler.get_last_lr()

-    if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet
-        logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+    if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
+        if args.network_train_unet_only:
+            logs["lr/unet"] = float(lrs[0])
+        elif args.network_train_text_encoder_only:
+            logs["lr/textencoder"] = float(lrs[0])
+        else:
+            logs["lr/textencoder"] = float(lrs[0])
+            logs["lr/unet"] = float(lrs[-1])  # may be the same as textencoder
+
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet
+            logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+    else:
+        idx = 0
+        if not args.network_train_unet_only:
+            logs["lr/textencoder"] = float(lrs[0])
+            idx = 1
+
+        for i in range(idx, len(lrs)):
+            logs[f"lr/block{i}"] = float(lrs[i])
+            if args.optimizer_type.lower() == "DAdaptation".lower():
+                logs[f"lr/d*lr/block{i}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+                )

     return logs
@@ -99,10 +114,10 @@ def train(args):
         blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
         train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

-    current_epoch = Value('i',0)
-    current_step = Value('i',0)
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
     ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
@@ -146,7 +161,6 @@ def train(args):
             torch.cuda.empty_cache()
         accelerator.wait_for_everyone()

-
     # incorporate xformers or memory-efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -214,7 +228,9 @@ def train(args):

     # calculate the number of training steps
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
         if is_main_process:
             print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
@@ -518,7 +534,7 @@ def train(args):
     for epoch in range(num_train_epochs):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
-        current_epoch.value = epoch+1
+        current_epoch.value = epoch + 1

         metadata["ss_epoch"] = str(epoch + 1)
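Note: how the per-block parameter groups built by prepare_optimizer_params surface in these logs can be sketched with toy modules. Everything below is an illustrative stand-in (plain Linear layers instead of LoRA modules, made-up block weights), assuming a recent PyTorch; it is not code from the patch.

    import torch

    blocks = {idx: torch.nn.Linear(4, 4) for idx in range(4)}  # stand-ins for per-block LoRA groups
    block_weights = {0: 0.0, 1: 0.5, 2: 1.0, 3: 1.0}  # e.g. values taken from down_lr_weight
    unet_lr = 1e-4

    param_groups = []
    for idx, module in blocks.items():
        lr = unet_lr * block_weights[idx]
        if lr == 0:
            continue  # as in the patch: zero-LR groups are dropped (and apply_to skips those LoRAs)
        param_groups.append({"params": module.parameters(), "lr": lr})

    optimizer = torch.optim.AdamW(param_groups)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

    # The scheduler reports one LR per param group, i.e. per block; with more than
    # two groups, generate_step_logs() takes the per-block branch and records
    # these values as lr/block{i}.
    print(scheduler.get_last_lr())  # [5e-05, 0.0001, 0.0001]

Because each block is its own param group, D-Adaptation's per-group d values can be logged the same way, which is what the lr/d*lr/block{i} keys above record.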