Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
implement stratified_lr
networks/lora.py (141 changed lines)
@@ -8,9 +8,11 @@ import os
 from typing import List
 import numpy as np
 import torch
+import re
 
 from library import train_util
 
+RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
 
 class LoRAModule(torch.nn.Module):
   """
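The new RE_UPDOWN pattern pulls four pieces out of a LoRA module name: the up/down side, the block index, the module type, and the module index. A minimal sketch of what it captures (the module name below is illustrative, following the lora_unet_* naming that sd-scripts derives from U-Net module paths):

import re

RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')

# Illustrative name; sd-scripts builds these by joining the module path with '_'.
name = "lora_unet_up_blocks_1_attentions_2_transformer_blocks_0_attn1_to_q"
print(RE_UPDOWN.search(name).groups())  # ('up', '1', 'attentions', '2')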
@@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
   else:
     conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
     assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas must match {NUM_BLOCKS}"
   """
 
   network = LoRANetwork(
     text_encoder,
@@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
     conv_lora_dim=conv_dim,
     conv_alpha=conv_alpha,
   )
 
+  up_weight = None
+  if 'up_weight' in kwargs:
+    up_weight = kwargs.get('up_weight', None)
+    if "," in up_weight:
+      up_weight = [float(s) for s in up_weight.split(",") if s]
+  down_weight = None
+  if 'down_weight' in kwargs:
+    down_weight = kwargs.get('down_weight', None)
+    if "," in down_weight:
+      down_weight = [float(s) for s in down_weight.split(",") if s]
+
+  network.set_stratified_lr_weight(up_weight, float(kwargs.get('mid_weight', 1.0)), down_weight, float(kwargs.get('lr_weight_threshold', 0.0)))
+
   return network
 
+
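These kwargs arrive as plain strings (via the trainer's --network_args key=value list), so a comma-separated value selects explicit per-depth multipliers while a bare preset name like cosine passes through and is resolved later by set_stratified_lr_weight. A minimal sketch of the parsing above, with hypothetical values:

# Hypothetical kwargs as they would arrive from --network_args.
kwargs = {"down_weight": "1,0.5,0", "up_weight": "cosine"}

down_weight = kwargs.get("down_weight", None)
if "," in down_weight:
  down_weight = [float(s) for s in down_weight.split(",") if s]

print(down_weight)          # [1.0, 0.5, 0.0]: explicit multipliers, shallow to deep
print(kwargs["up_weight"])  # 'cosine': resolved to a preset curve later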
@@ -318,6 +334,10 @@ class LoRANetwork(torch.nn.Module):
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)
 
+    self.up_weight: list[float] = None
+    self.down_weight: list[float] = None
+    self.mid_weight: float = None
+
   def set_multiplier(self, multiplier):
     self.multiplier = multiplier
     for lora in self.text_encoder_loras + self.unet_loras:
@@ -366,9 +386,17 @@ class LoRANetwork(torch.nn.Module):
     else:
       self.unet_loras = []
 
+    skipped = []
     for lora in self.text_encoder_loras + self.unet_loras:
+      if self.get_stratified_lr_weight(lora) == 0:
+        skipped.append(lora.lora_name)
+        continue
       lora.apply_to()
       self.add_module(lora.lora_name, lora)
+    if len(skipped) > 0:
+      print(f"the following {len(skipped)} LoRA modules are skipped because their stratified_lr_weight is 0:")
+      for name in skipped:
+        print(f"\t{name}")
 
     if self.weights_sd:
       # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
@@ -404,34 +432,113 @@ class LoRANetwork(torch.nn.Module):
         lora.merge_to(sd_for_lora, dtype, device)
     print(f"weights are merged")
 
-  def enable_gradient_checkpointing(self):
-    # not supported
-    pass
-
-  def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
-    def enumerate_params(loras):
-      params = []
-      for lora in loras:
-        params.extend(lora.parameters())
-      return params
-
+  # define per-layer multipliers on the learning rate for stratified (layer-wise) learning rates
+  def set_stratified_lr_weight(self, up_weight: list[float] | str = None, mid_weight: float = None, down_weight: list[float] | str = None, zero_threshold: float = 0.0):
+    max_len = 3  # defined for 3 modules: attentions -> attentions -> attentions
+    if self.apply_to_conv2d_3x3:
+      max_len = 10  # defined for 10 modules: (resnets -> {up,down}sampler -> attentions) x3 -> resnets
+
+    def get_list(name) -> list[float]:
+      import math
+      if name == "cosine":
+        return [math.cos(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+      elif name == "sine":
+        return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+      elif name == "linear":
+        return [i / (max_len - 1) for i in range(max_len)]
+      elif name == "reverse_linear":
+        return [i / (max_len - 1) for i in reversed(range(max_len))]
+      elif name == "zeros":
+        return [0.0] * max_len
+      else:
+        print("unknown weight argument %s was used.\n\tvalid arguments: cosine, sine, linear, reverse_linear, zeros" % (name))
+        return None
+
+    if isinstance(down_weight, str):
+      down_weight = get_list(down_weight)
+    if isinstance(up_weight, str):
+      up_weight = get_list(up_weight)
+
+    if (up_weight is not None and len(up_weight) > max_len) or (down_weight is not None and len(down_weight) > max_len):
+      print("down_weight or up_weight is too long; parameters after the first %d are ignored" % max_len)
+    if (up_weight is not None and len(up_weight) < max_len) or (down_weight is not None and len(down_weight) < max_len):
+      print("down_weight or up_weight is too short; missing parameters up to %d are padded with 1" % max_len)
+    if down_weight is not None and len(down_weight) < max_len:
+      down_weight = down_weight + [1.0] * (max_len - len(down_weight))
+    if up_weight is not None and len(up_weight) < max_len:
+      up_weight = up_weight + [1.0] * (max_len - len(up_weight))
+
+    if (up_weight is not None) or (mid_weight is not None) or (down_weight is not None):
+      print("applying stratified (layer-wise) learning rates")
+      if down_weight is not None:
+        self.down_weight = [w if w > zero_threshold else 0 for w in down_weight[:max_len]]
+        print("down_weight (shallow -> deep):", self.down_weight)
+      if mid_weight is not None:
+        self.mid_weight = mid_weight if mid_weight > zero_threshold else 0
+        print("mid_weight:", self.mid_weight)
+      if up_weight is not None:
+        self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]]
+        print("up_weight (deep -> shallow):", self.up_weight)
+    return
+
+  def get_stratified_lr_weight(self, lora: LoRAModule) -> float:
+    m = RE_UPDOWN.search(lora.lora_name)
+    if m:
+      idx = 0
+      g = m.groups()
+      i = int(g[1])
+      if self.apply_to_conv2d_3x3:
+        if g[2] == "resnets":
+          idx = 3 * i
+        elif g[2] == "attentions":
+          if g[0] == "down":
+            idx = 3 * i + 2
+          else:
+            idx = 3 * i - 1
+        elif g[2] == "upsamplers" or g[2] == "downsamplers":
+          idx = 3 * i + 1
+      else:
+        idx = i
+        if g[0] == "up":
+          idx = i - 1
+
+      if g[0] == "up" and self.up_weight is not None:
+        return self.up_weight[idx]
+      elif g[0] == "down" and self.down_weight is not None:
+        return self.down_weight[idx]
+    elif "mid_block_" in lora.lora_name and self.mid_weight is not None:
+      return self.mid_weight
+    # print({'params': lora.parameters(), 'lr': alpha * lr})
+    return 1
+
+  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
     self.requires_grad_(True)
     all_params = []
 
     if self.text_encoder_loras:
-      param_data = {"params": enumerate_params(self.text_encoder_loras)}
+      params = []
+      for lora in self.text_encoder_loras:
+        params.extend(lora.parameters())
+      param_data = {'params': params}
       if text_encoder_lr is not None:
-        param_data["lr"] = text_encoder_lr
+        param_data['lr'] = text_encoder_lr
       all_params.append(param_data)
 
     if self.unet_loras:
-      param_data = {"params": enumerate_params(self.unet_loras)}
-      if unet_lr is not None:
-        param_data["lr"] = unet_lr
-      all_params.append(param_data)
+      # one param group per LoRA module so each can carry its own lr
+      for lora in self.unet_loras:
+        param_data = {'params': lora.parameters()}
+        if unet_lr is not None:
+          param_data['lr'] = self.get_stratified_lr_weight(lora) * unet_lr
+        elif default_lr is not None:
+          param_data['lr'] = self.get_stratified_lr_weight(lora) * default_lr
+        if param_data.get('lr') == 0:
+          continue
+        all_params.append(param_data)
 
     return all_params
 
+  def enable_gradient_checkpointing(self):
+    # not supported
+    pass
+
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)
 
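To make the presets concrete, here is a standalone sketch (assuming the conv2d_3x3 case, so max_len = 10) of the cosine curve from get_list and of how zero_threshold interacts with the module skipping in apply_to above; the threshold value is illustrative:

import math

max_len = 10
cosine = [math.cos(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
print([round(w, 2) for w in cosine])
# [1.0, 0.98, 0.94, 0.87, 0.77, 0.64, 0.5, 0.34, 0.17, 0.0]

# Multipliers at or below zero_threshold are clamped to exactly 0, and
# modules whose weight is 0 are never instantiated (see apply_to above).
zero_threshold = 0.3
print([round(w, 2) if w > zero_threshold else 0 for w in cosine])
# [1.0, 0.98, 0.94, 0.87, 0.77, 0.64, 0.5, 0.34, 0, 0]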
train_network.py

@@ -191,7 +191,7 @@ def train(args):
   # prepare the classes required for training
   print("prepare optimizer, data loader etc.")
 
-  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
   optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # prepare the dataloader
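Downstream, the effect of the new default_lr argument (args.learning_rate) is that every U-Net LoRA module gets its own optimizer param group with lr = stratified weight * base lr, even when --unet_lr is unset. A minimal sketch of how PyTorch consumes such groups (the modules and values are illustrative stand-ins):

import torch

# Two stand-ins for LoRA modules at different depths.
shallow, deep = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)

unet_lr = 1e-4
param_groups = [
  {"params": shallow.parameters(), "lr": 1.0 * unet_lr},  # stratified weight 1.0
  {"params": deep.parameters(), "lr": 0.5 * unet_lr},     # stratified weight 0.5
]
optimizer = torch.optim.AdamW(param_groups)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 5e-05]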