Add LoRA-FA for LoRA+

commit c7691607ea
parent f99fe281cb
Author: rockerBOO
Date:   2024-04-01 15:43:04 -04:00


@@ -1033,22 +1033,43 @@ class LoRANetwork(torch.nn.Module):
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.get_trainable_named_params():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1083,15 @@ class LoRANetwork(torch.nn.Module):
                 # set parameters for each block
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+
+                    all_params.extend(params)
 
             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)
 
         return all_params
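
With this change a trainer passes the new LoRA+ ratios into prepare_optimizer_params and hands the returned groups straight to the optimizer. The sketch below is illustrative only and not part of this commit: `network` stands in for a constructed LoRANetwork, and the ratio value 16 is an arbitrary example. It shows the intended effect, namely that the "plus" (lora_up) group trains at the base learning rate multiplied by the ratio.

# Illustrative usage only -- not part of this commit.
# `network` is a hypothetical LoRANetwork instance; 16 is an example ratio.
import torch

param_groups = network.prepare_optimizer_params(
    text_encoder_lr=1e-5,
    unet_lr=1e-4,
    default_lr=1e-4,
    unet_lora_plus_ratio=16,          # "plus" (lora_up) U-Net params train at 1e-4 * 16
    text_encoder_lora_plus_ratio=16,  # "plus" text encoder params train at 1e-5 * 16
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)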
@@ -1093,6 +1108,9 @@ class LoRANetwork(torch.nn.Module):
     def get_trainable_params(self):
         return self.parameters()
 
+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
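
The new get_trainable_named_params hook exists so assemble_params can see parameter names and route anything containing "lora_up" into the boosted group, which is the core of LoRA+: the lora_up ("B") matrices get a larger step size than the lora_down ("A") matrices. Below is a minimal, self-contained sketch of that name-based grouping; ToyLoRA is a made-up module that only mirrors the lora_down/lora_up attribute names used by the real LoRA modules.

# Standalone sketch of the name-based grouping rule; ToyLoRA is hypothetical.
import torch

class ToyLoRA(torch.nn.Module):
    def __init__(self, in_dim=64, out_dim=64, rank=4):
        super().__init__()
        self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)  # "A" matrix
        self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)   # "B" matrix

for name, param in ToyLoRA().named_parameters():
    group = "plus" if "lora_up" in name else "lora"
    print(f"{name} -> {group}")
# prints:
#   lora_down.weight -> lora
#   lora_up.weight -> plus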