support individual LR for CLIP-L/T5XXL

Kohya S
2024-09-10 20:32:09 +09:00
parent d29af146b8
commit d10ff62a78
3 changed files with 49 additions and 58 deletions
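Illustration (not part of the commit): the new prepare_optimizer_params_with_multiple_te_lrs normalizes the text_encoder_lr argument into a two-element list, first entry for CLIP-L, second for T5XXL, falling back to default_lr when nothing is given and reusing a single value for both encoders. The helper name normalize_te_lrs and the sample values below are purely illustrative; only the normalization logic mirrors the diff.

# Sketch only: normalize_te_lrs and the sample values are illustrative, not repository code.
# It mirrors the list normalization at the top of prepare_optimizer_params_with_multiple_te_lrs.
from typing import List, Optional

def normalize_te_lrs(text_encoder_lr: Optional[List[float]], default_lr: float) -> List[float]:
    """Return [CLIP-L LR, T5XXL LR]."""
    if text_encoder_lr is None or len(text_encoder_lr) == 0:
        return [default_lr, default_lr]                  # nothing given: both use default_lr
    if len(text_encoder_lr) == 1:
        return [text_encoder_lr[0], text_encoder_lr[0]]  # one value: reuse it for both encoders
    return list(text_encoder_lr[:2])                     # two values: CLIP-L first, T5XXL second

print(normalize_te_lrs(None, 1e-4))          # [0.0001, 0.0001]
print(normalize_te_lrs([5e-5], 1e-4))        # [5e-05, 5e-05]
print(normalize_te_lrs([1e-4, 5e-6], 1e-4))  # [0.0001, 5e-06]

The first value then drives the CLIP-L LoRA modules and the second the T5XXL modules, matching the te1/te3 split in the diff below.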


@@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
# if (
# self.loraplus_lr_ratio is not None
# or self.loraplus_text_encoder_lr_ratio is not None
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr as list of two elements
if text_encoder_lr is None or len(text_encoder_lr) == 0:
text_encoder_lr = [default_lr, default_lr]
elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, ratio):
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
@@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module):
                 if lr is not None:
                     if key == "plus":
-                        param_data["lr"] = lr * ratio
+                        param_data["lr"] = lr * loraplus_ratio
                     else:
                         param_data["lr"] = lr
@@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module):
             return params, descriptions

         if self.text_encoder_loras:
-            params, descriptions = assemble_params(
-                self.text_encoder_loras,
-                text_encoder_lr if text_encoder_lr is not None else default_lr,
-                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
-            )
-            all_params.extend(params)
-            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
+            loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
+
+            # split text encoder loras for te1 and te3
+            te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)]
+            te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
+            if len(te1_loras) > 0:
+                logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
+                params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
+                all_params.extend(params)
+                lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
+            if len(te3_loras) > 0:
+                logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}")
+                params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio)
+                all_params.extend(params)
+                lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])

         if self.unet_loras:
-            # if self.block_lr:
-            #     is_sdxl = False
-            #     for lora in self.unet_loras:
-            #         if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
-            #             is_sdxl = True
-            #             break
-            #     # we want the learning-rate graph per block, so group the loras by block
-            #     block_idx_to_lora = {}
-            #     for lora in self.unet_loras:
-            #         idx = get_block_index(lora.lora_name, is_sdxl)
-            #         if idx not in block_idx_to_lora:
-            #             block_idx_to_lora[idx] = []
-            #         block_idx_to_lora[idx].append(lora)
-            #     # set the parameters per block
-            #     for idx, block_loras in block_idx_to_lora.items():
-            #         params, descriptions = assemble_params(
-            #             block_loras,
-            #             (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
-            #             self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
-            #         )
-            #         all_params.extend(params)
-            #         lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
-            # else:
             params, descriptions = assemble_params(
                 self.unet_loras,
                 unet_lr if unet_lr is not None else default_lr,