Support LR graphs for each block, base lr

This commit is contained in:
Kohya S
2023-04-03 08:43:11 +09:00
parent c639cb7d5d
commit 3beddf341e
2 changed files with 98 additions and 43 deletions

View File

@@ -5,7 +5,7 @@
import math
import os
from typing import List, Union
from typing import List, Tuple, Union
import numpy as np
import torch
import re
@@ -247,6 +247,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
@@ -394,7 +396,7 @@ class LoRANetwork(torch.nn.Module):
skipped = []
for lora in self.text_encoder_loras + self.unet_loras:
if self.block_lr and self.get_block_lr_weight(lora) == 0:
if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight
skipped.append(lora.lora_name)
continue
lora.apply_to()
@@ -450,25 +452,29 @@ class LoRANetwork(torch.nn.Module):
down_lr_weight: Union[List[float], str] = None,
zero_threshold: float = 0.0,
):
# パラメータ未指定時は何もせず、今までと同じ動作とする
# パラメータ未指定時は何もせず、今までと同じ動作とする
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return
max_len = 12 # フルモデル相当でのup,downの層の数
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
def get_list(name) -> List[float]:
def get_list(name_with_suffix) -> List[float]:
import math
tokens = name_with_suffix.split("+")
name = tokens[0]
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
elif name == "linear":
return [i / (max_len - 1) for i in range(max_len)]
return [i / (max_len - 1) + base_lr for i in range(max_len)]
elif name == "reverse_linear":
return [i / (max_len - 1) for i in reversed(range(max_len))]
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
elif name == "zeros":
return [0.0] * max_len
return [0.0 + base_lr] * max_len
else:
print(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
@@ -520,7 +526,9 @@ class LoRANetwork(torch.nn.Module):
return
def get_block_lr_weight(self, lora: LoRAModule) -> float:
def get_block_index(self, lora: LoRAModule) -> int:
block_idx = -1 # invalid lora name
m = RE_UPDOWN.search(lora.lora_name)
if m:
g = m.groups()
@@ -533,43 +541,74 @@ class LoRANetwork(torch.nn.Module):
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if (g[0] == "down") and (self.down_lr_weight != None):
return self.down_lr_weight[idx + 1]
if (g[0] == "up") and (self.up_lr_weight != None):
return self.up_lr_weight[idx]
elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
return self.mid_lr_weight
return 1
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora.lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
return block_idx
def get_lr_weight(self, lora: LoRAModule) -> float:
    """Return the learning-rate multiplier for the block that *lora* belongs to.

    The block index is obtained from ``self.get_block_index(lora)``:

    * negative index: no block could be determined, so the weight is 1.0;
    * index below ``LoRANetwork.NUM_OF_BLOCKS``: ``self.down_lr_weight[index]``;
    * index equal to ``LoRANetwork.NUM_OF_BLOCKS``: ``self.mid_lr_weight``;
    * index above ``LoRANetwork.NUM_OF_BLOCKS``:
      ``self.up_lr_weight[index - LoRANetwork.NUM_OF_BLOCKS - 1]``.

    If the weight setting for the selected range is ``None``, the default
    weight of 1.0 is returned.
    """
    lr_weight = 1.0
    block_idx = self.get_block_index(lora)
    if block_idx < 0:
        return lr_weight

    # Use identity checks against None so that a weight container with an
    # overloaded comparison operator cannot change the outcome of the test.
    if block_idx < LoRANetwork.NUM_OF_BLOCKS:
        if self.down_lr_weight is not None:
            lr_weight = self.down_lr_weight[block_idx]
    elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
        if self.mid_lr_weight is not None:
            lr_weight = self.mid_lr_weight
    elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
        if self.up_lr_weight is not None:
            # Indices above NUM_OF_BLOCKS are offset by NUM_OF_BLOCKS + 1.
            lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]

    return lr_weight
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
if self.text_encoder_loras:
def enumerate_params(loras):
params = []
for lora in self.text_encoder_loras:
for lora in loras:
params.extend(lora.parameters())
param_data = {"params": params}
return params
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
if self.unet_loras:
if self.block_lr:
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
block_idx_to_lora = {}
for lora in self.unet_loras:
param_data = {"params": lora.parameters()}
idx = self.get_block_index(lora)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)
# blockごとにパラメータを設定する
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
else:
params = []
for lora in self.unet_loras:
params.extend(lora.parameters())
param_data = {"params": params}
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)