mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Suppor LR graphs for each block, base lr
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import List, Union
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
@@ -247,6 +247,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||
|
||||
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
@@ -394,7 +396,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
skipped = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if self.block_lr and self.get_block_lr_weight(lora) == 0:
|
||||
if self.block_lr and self.get_lr_weight(lora) == 0: # no LR weight
|
||||
skipped.append(lora.lora_name)
|
||||
continue
|
||||
lora.apply_to()
|
||||
@@ -450,25 +452,29 @@ class LoRANetwork(torch.nn.Module):
|
||||
down_lr_weight: Union[List[float], str] = None,
|
||||
zero_threshold: float = 0.0,
|
||||
):
|
||||
# バラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
||||
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
||||
return
|
||||
|
||||
max_len = 12 # フルモデル相当でのup,downの層の数
|
||||
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
||||
|
||||
def get_list(name) -> List[float]:
|
||||
def get_list(name_with_suffix) -> List[float]:
|
||||
import math
|
||||
|
||||
tokens = name_with_suffix.split("+")
|
||||
name = tokens[0]
|
||||
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
||||
|
||||
if name == "cosine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "sine":
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
|
||||
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
||||
elif name == "linear":
|
||||
return [i / (max_len - 1) for i in range(max_len)]
|
||||
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
||||
elif name == "reverse_linear":
|
||||
return [i / (max_len - 1) for i in reversed(range(max_len))]
|
||||
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
||||
elif name == "zeros":
|
||||
return [0.0] * max_len
|
||||
return [0.0 + base_lr] * max_len
|
||||
else:
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
@@ -520,7 +526,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
return
|
||||
|
||||
def get_block_lr_weight(self, lora: LoRAModule) -> float:
|
||||
def get_block_index(self, lora: LoRAModule) -> int:
|
||||
block_idx = -1 # invalid lora name
|
||||
|
||||
m = RE_UPDOWN.search(lora.lora_name)
|
||||
if m:
|
||||
g = m.groups()
|
||||
@@ -533,43 +541,74 @@ class LoRANetwork(torch.nn.Module):
|
||||
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
||||
idx = 3 * i + 2
|
||||
|
||||
if (g[0] == "down") and (self.down_lr_weight != None):
|
||||
return self.down_lr_weight[idx + 1]
|
||||
if (g[0] == "up") and (self.up_lr_weight != None):
|
||||
return self.up_lr_weight[idx]
|
||||
elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
|
||||
return self.mid_lr_weight
|
||||
return 1
|
||||
if g[0] == "down":
|
||||
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
||||
elif g[0] == "up":
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
||||
|
||||
elif "mid_block_" in lora.lora_name:
|
||||
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
||||
|
||||
return block_idx
|
||||
|
||||
def get_lr_weight(self, lora: LoRAModule) -> float:
|
||||
lr_weight = 1.0
|
||||
block_idx = self.get_block_index(lora)
|
||||
if block_idx < 0:
|
||||
return lr_weight
|
||||
|
||||
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.down_lr_weight != None:
|
||||
lr_weight = self.down_lr_weight[block_idx]
|
||||
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.mid_lr_weight != None:
|
||||
lr_weight = self.mid_lr_weight
|
||||
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
||||
if self.up_lr_weight != None:
|
||||
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
||||
|
||||
return lr_weight
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
if self.text_encoder_loras:
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
for lora in self.text_encoder_loras:
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
param_data = {"params": params}
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
if self.block_lr:
|
||||
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
block_idx_to_lora = {}
|
||||
for lora in self.unet_loras:
|
||||
param_data = {"params": lora.parameters()}
|
||||
idx = self.get_block_index(lora)
|
||||
if idx not in block_idx_to_lora:
|
||||
block_idx_to_lora[idx] = []
|
||||
block_idx_to_lora[idx].append(lora)
|
||||
|
||||
# blockごとにパラメータを設定する
|
||||
for idx, block_loras in block_idx_to_lora.items():
|
||||
param_data = {"params": enumerate_params(block_loras)}
|
||||
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr * self.get_block_lr_weight(lora)
|
||||
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||
elif default_lr is not None:
|
||||
param_data["lr"] = default_lr * self.get_block_lr_weight(lora)
|
||||
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
|
||||
else:
|
||||
params = []
|
||||
for lora in self.unet_loras:
|
||||
params.extend(lora.parameters())
|
||||
param_data = {"params": params}
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
@@ -32,16 +32,31 @@ from library.custom_train_functions import apply_snr_weight
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/block{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower():
|
||||
logs[f"lr/d*lr/block{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
@@ -99,8 +114,8 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
@@ -146,7 +161,6 @@ def train(args):
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@@ -214,7 +228,9 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user