mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
@@ -2920,6 +2920,9 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
default=1,
|
||||
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
|
||||
)
|
||||
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
|
||||
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
|
||||
@@ -406,27 +406,63 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
logger.info(f"weights are merged")
|
||||
"""
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(
|
||||
self,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
default_lr,
|
||||
text_encoder_loraplus_ratio=None,
|
||||
unet_loraplus_ratio=None,
|
||||
loraplus_ratio=None
|
||||
):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_B" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
text_encoder_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.unet_loras,
|
||||
default_lr if unet_lr is None else unet_lr,
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
return all_params
|
||||
|
||||
|
||||
@@ -1034,21 +1034,55 @@ class LoRANetwork(torch.nn.Module):
|
||||
return lr_weight
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
def prepare_optimizer_params(
|
||||
self,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
default_lr,
|
||||
text_encoder_loraplus_ratio=None,
|
||||
unet_loraplus_ratio=None,
|
||||
loraplus_ratio=None
|
||||
):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
print("NO LR skipping!")
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
text_encoder_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
if self.unet_loras:
|
||||
if self.block_lr:
|
||||
@@ -1062,21 +1096,20 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
# blockごとにパラメータを設定する
|
||||
for idx, block_loras in block_idx_to_lora.items():
|
||||
param_data = {"params": enumerate_params(block_loras)}
|
||||
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||
elif default_lr is not None:
|
||||
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
block_loras,
|
||||
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
else:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
return all_params
|
||||
|
||||
|
||||
@@ -1033,22 +1033,54 @@ class LoRANetwork(torch.nn.Module):
|
||||
return lr_weight
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
def prepare_optimizer_params(
|
||||
self,
|
||||
text_encoder_lr,
|
||||
unet_lr,
|
||||
default_lr,
|
||||
text_encoder_loraplus_ratio=None,
|
||||
unet_loraplus_ratio=None,
|
||||
loraplus_ratio=None
|
||||
):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras: List[LoRAModule]):
|
||||
params = []
|
||||
def assemble_params(loras, lr, ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
# params.extend(lora.parameters())
|
||||
params.extend(lora.get_trainable_params())
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
|
||||
params = []
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
continue
|
||||
|
||||
params.append(param_data)
|
||||
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
text_encoder_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
if self.unet_loras:
|
||||
if self.block_lr:
|
||||
@@ -1062,21 +1094,20 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
# blockごとにパラメータを設定する
|
||||
for idx, block_loras in block_idx_to_lora.items():
|
||||
param_data = {"params": enumerate_params(block_loras)}
|
||||
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||
elif default_lr is not None:
|
||||
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||
continue
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
block_loras,
|
||||
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
else:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
params = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
unet_loraplus_ratio or loraplus_ratio
|
||||
)
|
||||
all_params.extend(params)
|
||||
|
||||
return all_params
|
||||
|
||||
@@ -1093,6 +1124,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def get_trainable_named_params(self):
|
||||
return self.named_parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
@@ -64,34 +64,69 @@ class NetworkTrainer:
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
)
|
||||
else:
|
||||
if len(lrs) > 4:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
lora_plus = ""
|
||||
group_id = i
|
||||
|
||||
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
||||
lora_plus = '_lora+' if i % 2 == 1 else ''
|
||||
group_id = int((i / 2) + (i % 2 + 0.5))
|
||||
|
||||
logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
else:
|
||||
if args.network_train_text_encoder_only:
|
||||
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
|
||||
elif args.network_train_unet_only:
|
||||
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
logs["lr/unet_lora+"] = float(lrs[1])
|
||||
else:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
else:
|
||||
if len(lrs) == 2:
|
||||
if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
||||
elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
logs["lr/unet_lora+"] = float(lrs[1])
|
||||
elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
|
||||
logs["lr/all"] = float(lrs[0])
|
||||
logs["lr/all_lora+"] = float(lrs[1])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1])
|
||||
elif len(lrs) == 4:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
||||
logs["lr/unet"] = float(lrs[2])
|
||||
logs["lr/unet_lora+"] = float(lrs[3])
|
||||
else:
|
||||
logs["lr/all"] = float(lrs[0])
|
||||
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
@@ -338,7 +373,7 @@ class NetworkTrainer:
|
||||
|
||||
# 後方互換性を確保するよ
|
||||
try:
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
|
||||
except TypeError:
|
||||
accelerator.print(
|
||||
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
@@ -347,6 +382,11 @@ class NetworkTrainer:
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
||||
assert (
|
||||
(optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
|
||||
), "LoRA+ and Prodigy/DAdaptation is not supported"
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
||||
Reference in New Issue
Block a user