Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Add LoRA+ support
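This commit extends prepare_optimizer_params() in both DyLoRANetwork and LoRANetwork: instead of putting every LoRA parameter of a module set into a single optimizer group, it splits them into a "lora" group and a "plus" group (the lora_up / "B" weights) and multiplies the learning rate of the "plus" group by a new ratio argument (unet_lora_plus_ratio, text_encoder_lora_plus_ratio). This follows the LoRA+ recipe of training the up-projection at a higher learning rate than the down-projection. When a ratio is None all parameters stay in the "lora" group, and any group whose resolved learning rate is 0 is skipped entirely.

A minimal standalone sketch of the idea (the module shapes, learning rate, and ratio below are made up for illustration; this is not code from the commit):

import torch

base_lr, ratio = 1e-4, 16                        # hypothetical learning rate and LoRA+ ratio
lora_down = torch.nn.Linear(320, 4, bias=False)  # "A" / lora_down matrix, trained at base_lr
lora_up = torch.nn.Linear(4, 320, bias=False)    # "B" / lora_up matrix, trained at base_lr * ratio

optimizer = torch.optim.AdamW(
    [
        {"params": lora_down.parameters(), "lr": base_lr},
        {"params": lora_up.parameters(), "lr": base_lr * ratio},
    ]
)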
networks/dylora.py
@@ -406,27 +406,48 @@ class DyLoRANetwork(torch.nn.Module):
             logger.info(f"weights are merged")

     """

-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)

         return all_params
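For reference, here is a rough standalone sketch of the grouping logic in assemble_params above. ToyLoRAModule and its attribute names are illustrative; they only mirror the lora_name / lora_up / lora_down naming that the diff matches on:

import torch

class ToyLoRAModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_name = "lora_unet_toy"
        self.lora_down = torch.nn.Linear(320, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 320, bias=False)

module = ToyLoRAModule()
param_groups = {"lora": {}, "plus": {}}
for name, param in module.named_parameters():
    # "lora_up" parameters go to the "plus" group, everything else stays in "lora"
    group = "plus" if "lora_up" in name else "lora"
    param_groups[group][f"{module.lora_name}.{name}"] = param

print(list(param_groups["plus"]))  # ['lora_unet_toy.lora_up.weight']
print(list(param_groups["lora"]))  # ['lora_unet_toy.lora_down.weight']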
networks/lora.py
@@ -1035,21 +1035,43 @@ class LoRANetwork(torch.nn.Module):
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1063,21 +1085,15 @@ class LoRANetwork(torch.nn.Module):
                     block_idx_to_lora[idx].append(lora)

                 # blockごとにパラメータを設定する
                 for idx, block_loras in block_idx_to_lora.items():
-                    param_data = {"params": enumerate_params(block_loras)}
-
                     if unet_lr is not None:
-                        param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                        params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                     elif default_lr is not None:
-                        param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                    if ("lr" in param_data) and (param_data["lr"] == 0):
-                        continue
-                    all_params.append(param_data)
+                        params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                    all_params.extend(params)

             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+                all_params.extend(params)

         return all_params
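In the block_lr branch above, the per-block weight from get_lr_weight() is applied first and the LoRA+ ratio is then applied on top of it inside assemble_params, so the lora_up weights of a block train at unet_lr * block_weight * ratio. A small arithmetic sketch with made-up numbers:

# Hypothetical values; block_lr_weight stands in for self.get_lr_weight(block_loras[0])
unet_lr = 1e-4
unet_lora_plus_ratio = 16
block_lr_weight = 0.5

lora_group_lr = unet_lr * block_lr_weight              # lr for the block's "lora" group
plus_group_lr = lora_group_lr * unet_lora_plus_ratio   # lr for the block's lora_up ("plus") group
print(lora_group_lr, plus_group_lr)                    # 5e-05 0.0008

In training code, assuming network is the LoRANetwork instance built elsewhere, the returned list can be handed directly to a torch optimizer, e.g. torch.optim.AdamW(network.prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=16)). Because groups whose learning rate resolves to 0 are skipped, setting unet_lr=0 or text_encoder_lr=0 now excludes those weights from the optimizer entirely.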