mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support individual LR for CLIP-L/T5XXL
This commit is contained in:
@@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
||||
# if (
|
||||
# self.loraplus_lr_ratio is not None
|
||||
# or self.loraplus_text_encoder_lr_ratio is not None
|
||||
# or self.loraplus_unet_lr_ratio is not None
|
||||
# ):
|
||||
# assert (
|
||||
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
||||
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
||||
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# make sure text_encoder_lr as list of two elements
|
||||
if text_encoder_lr is None or len(text_encoder_lr) == 0:
|
||||
text_encoder_lr = [default_lr, default_lr]
|
||||
elif len(text_encoder_lr) == 1:
|
||||
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
|
||||
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, ratio):
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_up" in name:
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
@@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
param_data["lr"] = lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
@@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
params, descriptions = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
||||
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
|
||||
# split text encoder loras for te1 and te3
|
||||
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)]
|
||||
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
|
||||
if len(te1_loras) > 0:
|
||||
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
|
||||
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
|
||||
if len(te3_loras) > 0:
|
||||
logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}")
|
||||
params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
# if self.block_lr:
|
||||
# is_sdxl = False
|
||||
# for lora in self.unet_loras:
|
||||
# if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
|
||||
# is_sdxl = True
|
||||
# break
|
||||
|
||||
# # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
# block_idx_to_lora = {}
|
||||
# for lora in self.unet_loras:
|
||||
# idx = get_block_index(lora.lora_name, is_sdxl)
|
||||
# if idx not in block_idx_to_lora:
|
||||
# block_idx_to_lora[idx] = []
|
||||
# block_idx_to_lora[idx].append(lora)
|
||||
|
||||
# # blockごとにパラメータを設定する
|
||||
# for idx, block_loras in block_idx_to_lora.items():
|
||||
# params, descriptions = assemble_params(
|
||||
# block_loras,
|
||||
# (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
|
||||
# self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
# )
|
||||
# all_params.extend(params)
|
||||
# lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
# else:
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
|
||||
Reference in New Issue
Block a user