mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Add LoRA+ support
@@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
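Both new flags default to None, so existing configurations are unaffected until a ratio is set. As a quick illustration (a standalone sketch, not the repo's parser; the value 16 is the ratio commonly cited for LoRA+, not one this commit prescribes):

import argparse

# Rebuild just the two new options to show how they parse.
parser = argparse.ArgumentParser()
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

args = parser.parse_args(["--loraplus_unet_lr_ratio", "16"])
print(args.loraplus_unet_lr_ratio)          # 16.0
print(args.loraplus_text_encoder_lr_ratio)  # None -> LoRA+ split disabled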
@@ -406,27 +406,48 @@ class DyLoRANetwork(torch.nn.Module):
         logger.info(f"weights are merged")
         """

-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # It might be good to allow separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)

         return all_params
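In isolation, the new assemble_params splits each LoRA module's parameters by name: anything containing "lora_up" (the B matrix, in LoRA+ terms) lands in a "plus" group trained at lr * lora_plus_ratio, everything else stays at the base lr, and any group whose resolved lr is 0 is dropped. A minimal runnable sketch of that split (DummyLoRA and its shapes are illustrative; only the lora_up/lora_down naming mirrors the real modules):

import torch

class DummyLoRA(torch.nn.Module):
    def __init__(self, lora_name):
        super().__init__()
        self.lora_name = lora_name
        self.lora_down = torch.nn.Linear(8, 4, bias=False)  # A: kept at the base lr
        self.lora_up = torch.nn.Linear(4, 8, bias=False)    # B: boosted by the ratio

base_lr, ratio = 1e-4, 16.0
groups = {"lora": [], "plus": []}
for lora in [DummyLoRA("lora_unet_block0")]:
    for name, param in lora.named_parameters():
        groups["plus" if "lora_up" in name else "lora"].append(param)

optimizer = torch.optim.AdamW([
    {"params": groups["lora"], "lr": base_lr},
    {"params": groups["plus"], "lr": base_lr * ratio},
])
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0016]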
@@ -1035,21 +1035,43 @@ class LoRANetwork(torch.nn.Module):
         return lr_weight

     # It might be good to allow separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1063,21 +1085,15 @@ class LoRANetwork(torch.nn.Module):

             # set the parameters for each block
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
-
                 if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                 elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                all_params.extend(params)

         else:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)

         return all_params
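Note how the ratio composes with block-wise learning rates: assemble_params receives an lr that has already been scaled by self.get_lr_weight(...), then multiplies it again by the ratio for the "plus" group, so lora_up weights in a down-weighted block keep the full relative boost. With purely illustrative numbers:

# Illustrative values; block_weight stands in for self.get_lr_weight(block_loras[0]).
unet_lr, block_weight, loraplus_ratio = 1e-4, 0.5, 16.0
lr_down = unet_lr * block_weight                 # 5e-05 for the "lora" group (lora_down)
lr_up = unet_lr * block_weight * loraplus_ratio  # 8e-04 for the "plus" group (lora_up)
print(lr_down, lr_up)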
@@ -339,7 +339,7 @@ class NetworkTrainer:

         # ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
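The try/except is there because --network_module can point at a third-party network class whose prepare_optimizer_params still has an older signature; passing the extra arguments to it raises TypeError and triggers the deprecation path. The excerpt ends before the fallback call, so the stub below is an assumption about the pattern, not the trainer's actual code:

# Sketch of the backward-compatibility pattern with a stub network.
class OldNetwork:
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):  # legacy signature
        return [{"params": [], "lr": unet_lr}]

def get_trainable_params(network, te_lr, unet_lr, default_lr, te_ratio, unet_ratio):
    try:
        return network.prepare_optimizer_params(te_lr, unet_lr, default_lr, te_ratio, unet_ratio)
    except TypeError:
        print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate)")
        return network.prepare_optimizer_params(te_lr, unet_lr)

print(get_trainable_params(OldNetwork(), 5e-5, 1e-4, 1e-4, None, None))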