Support LR graphs for each block, base lr

This commit is contained in:
Kohya S
2023-04-03 08:43:11 +09:00
parent c639cb7d5d
commit 3beddf341e
2 changed files with 98 additions and 43 deletions

networks/lora.py

@@ -5,7 +5,7 @@
 import math
 import os
-from typing import List, Union
+from typing import List, Tuple, Union
 import numpy as np
 import torch
 import re
@@ -247,6 +247,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class LoRANetwork(torch.nn.Module):
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
     # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
@@ -394,7 +396,7 @@ class LoRANetwork(torch.nn.Module):
         skipped = []
         for lora in self.text_encoder_loras + self.unet_loras:
-            if self.block_lr and self.get_block_lr_weight(lora) == 0:
+            if self.block_lr and self.get_lr_weight(lora) == 0:  # no LR weight
                 skipped.append(lora.lora_name)
                 continue
             lora.apply_to()
@@ -450,25 +452,29 @@ class LoRANetwork(torch.nn.Module):
         down_lr_weight: Union[List[float], str] = None,
         zero_threshold: float = 0.0,
     ):
         # if no parameter is specified, do nothing and keep the previous behavior
         if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
             return

-        max_len = 12  # number of up/down blocks in the full model
+        max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model

-        def get_list(name) -> List[float]:
+        def get_list(name_with_suffix) -> List[float]:
             import math

+            tokens = name_with_suffix.split("+")
+            name = tokens[0]
+            base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
+
             if name == "cosine":
-                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
             elif name == "sine":
-                return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+                return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
             elif name == "linear":
-                return [i / (max_len - 1) for i in range(max_len)]
+                return [i / (max_len - 1) + base_lr for i in range(max_len)]
             elif name == "reverse_linear":
-                return [i / (max_len - 1) for i in reversed(range(max_len))]
+                return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
             elif name == "zeros":
-                return [0.0] * max_len
+                return [0.0 + base_lr] * max_len
             else:
                 print(
                     "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
@@ -520,7 +526,9 @@ class LoRANetwork(torch.nn.Module):
             return

-    def get_block_lr_weight(self, lora: LoRAModule) -> float:
+    def get_block_index(self, lora: LoRAModule) -> int:
+        block_idx = -1  # invalid lora name
+
         m = RE_UPDOWN.search(lora.lora_name)
         if m:
             g = m.groups()
@@ -533,43 +541,74 @@ class LoRANetwork(torch.nn.Module):
             elif g[2] == "upsamplers" or g[2] == "downsamplers":
                 idx = 3 * i + 2

-            if (g[0] == "down") and (self.down_lr_weight != None):
-                return self.down_lr_weight[idx + 1]
-            if (g[0] == "up") and (self.up_lr_weight != None):
-                return self.up_lr_weight[idx]
-        elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):
-            return self.mid_lr_weight
-        return 1
+            if g[0] == "down":
+                block_idx = 1 + idx  # no LoRA corresponds to index 0
+            elif g[0] == "up":
+                block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
+        elif "mid_block_" in lora.lora_name:
+            block_idx = LoRANetwork.NUM_OF_BLOCKS  # idx=12
+
+        return block_idx
+
+    def get_lr_weight(self, lora: LoRAModule) -> float:
+        lr_weight = 1.0
+        block_idx = self.get_block_index(lora)
+        if block_idx < 0:
+            return lr_weight
+
+        if block_idx < LoRANetwork.NUM_OF_BLOCKS:
+            if self.down_lr_weight != None:
+                lr_weight = self.down_lr_weight[block_idx]
+        elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
+            if self.mid_lr_weight != None:
+                lr_weight = self.mid_lr_weight
+        elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
+            if self.up_lr_weight != None:
+                lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
+
+        return lr_weight
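To spell out the index space that get_block_index and get_lr_weight share, here is a minimal sketch (hypothetical helper name, not part of the commit): index -1 means the LoRA name matched no block, 1-11 are the down blocks (0 is never produced), 12 (NUM_OF_BLOCKS) is the mid block, and 13-24 are the up blocks, which map back to up_lr_weight[0..11].

    NUM_OF_BLOCKS = 12

    def lr_weight_for(block_idx, down_lr_weight, mid_lr_weight, up_lr_weight):
        # mirrors LoRANetwork.get_lr_weight over the combined index space
        if block_idx < 0:
            return 1.0  # name matched no block, weight stays 1.0
        if block_idx < NUM_OF_BLOCKS:
            return down_lr_weight[block_idx] if down_lr_weight is not None else 1.0
        if block_idx == NUM_OF_BLOCKS:
            return mid_lr_weight if mid_lr_weight is not None else 1.0
        return up_lr_weight[block_idx - NUM_OF_BLOCKS - 1] if up_lr_weight is not None else 1.0

    down = [i / 11 for i in range(12)]         # the "linear" profile
    print(lr_weight_for(1, down, 0.5, None))   # 0.0909... (second down entry)
    print(lr_weight_for(12, down, 0.5, None))  # 0.5 (mid)
    print(lr_weight_for(24, down, 0.5, None))  # 1.0 (up list not set)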
     def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

+        def enumerate_params(loras):
+            params = []
+            for lora in loras:
+                params.extend(lora.parameters())
+            return params
+
         if self.text_encoder_loras:
-            params = []
-            for lora in self.text_encoder_loras:
-                params.extend(lora.parameters())
-            param_data = {"params": params}
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
             if text_encoder_lr is not None:
                 param_data["lr"] = text_encoder_lr
             all_params.append(param_data)

         if self.unet_loras:
             if self.block_lr:
+                # group the LoRAs by block so the learning rate can be graphed per block
+                block_idx_to_lora = {}
                 for lora in self.unet_loras:
-                    param_data = {"params": lora.parameters()}
+                    idx = self.get_block_index(lora)
+                    if idx not in block_idx_to_lora:
+                        block_idx_to_lora[idx] = []
+                    block_idx_to_lora[idx].append(lora)
+
+                # set the parameters for each block
+                for idx, block_loras in block_idx_to_lora.items():
+                    param_data = {"params": enumerate_params(block_loras)}
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
+                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
+                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
                     if ("lr" in param_data) and (param_data["lr"] == 0):
                         continue
                     all_params.append(param_data)
             else:
-                params = []
-                for lora in self.unet_loras:
-                    params.extend(lora.parameters())
-                param_data = {"params": params}
+                param_data = {"params": enumerate_params(self.unet_loras)}
                 if unet_lr is not None:
                     param_data["lr"] = unet_lr
                 all_params.append(param_data)
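The per-block groups returned here are ordinary torch param groups, so an optimizer consumes them directly; groups whose scaled LR is exactly 0 were already dropped above, so frozen blocks never reach the optimizer. A usage sketch under stated assumptions (a constructed LoRANetwork instance called network with block_lr enabled; the LR values are illustrative):

    import torch

    params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
    optimizer = torch.optim.AdamW(params, lr=1e-4)  # groups without "lr" fall back to 1e-4
    for i, group in enumerate(optimizer.param_groups):
        print(i, group["lr"], len(group["params"]))  # one group per block (plus text encoder)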

train_network.py

@@ -32,16 +32,31 @@ from library.custom_train_functions import apply_snr_weight
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
     logs = {"loss/current": current_loss, "loss/average": avr_loss}

-    if args.network_train_unet_only:
-        logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
-    elif args.network_train_text_encoder_only:
-        logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
-    else:
-        logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
-        logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be same to textencoder
-
-    if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
-        logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+    lrs = lr_scheduler.get_last_lr()
+
+    if args.network_train_text_encoder_only or len(lrs) <= 2:  # not block lr (or single block)
+        if args.network_train_unet_only:
+            logs["lr/unet"] = float(lrs[0])
+        elif args.network_train_text_encoder_only:
+            logs["lr/textencoder"] = float(lrs[0])
+        else:
+            logs["lr/textencoder"] = float(lrs[0])
+            logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
+
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
+            logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+    else:
+        idx = 0
+        if not args.network_train_unet_only:
+            logs["lr/textencoder"] = float(lrs[0])
+            idx = 1
+
+        for i in range(idx, len(lrs)):
+            logs[f"lr/block{i}"] = float(lrs[i])
+            if args.optimizer_type.lower() == "DAdaptation".lower():
+                logs[f"lr/d*lr/block{i}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
+                )

     return logs
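The new branch emits one lr/block{i} key per U-Net param group, matching the per-block groups created in prepare_optimizer_params. A small sketch of the resulting log dict (assumed values; one text encoder group plus three U-Net block groups):

    lrs = [5e-5, 1e-4, 2e-4, 3e-4]  # what lr_scheduler.get_last_lr() might return
    logs = {"lr/textencoder": float(lrs[0])}
    for i in range(1, len(lrs)):
        logs[f"lr/block{i}"] = float(lrs[i])
    print(logs)
    # {'lr/textencoder': 5e-05, 'lr/block1': 0.0001, 'lr/block2': 0.0002, 'lr/block3': 0.0003}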
@@ -99,8 +114,8 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

-    current_epoch = Value('i',0)
-    current_step = Value('i',0)
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
     ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
     collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
@@ -146,7 +161,6 @@ def train(args):
             torch.cuda.empty_cache()
         accelerator.wait_for_everyone()
-
     # integrate xformers or memory efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -214,7 +228,9 @@ def train(args):
     # calculate the number of training steps
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
         if is_main_process:
             print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")