mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
move loraplus args from args to network_args, simplify log lr desc
This commit is contained in:
@@ -2920,9 +2920,6 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=1,
|
default=1,
|
||||||
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
|
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
|
||||||
)
|
)
|
||||||
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
|
||||||
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
|
||||||
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
|
||||||
|
|
||||||
|
|
||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
|
|||||||
@@ -490,6 +490,14 @@ def create_network(
|
|||||||
varbose=True,
|
varbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||||
|
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
||||||
|
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
||||||
|
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
||||||
|
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
||||||
|
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
||||||
|
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
||||||
|
|
||||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||||
|
|
||||||
@@ -1033,18 +1041,27 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
return lr_weight
|
return lr_weight
|
||||||
|
|
||||||
|
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||||
|
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||||
|
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||||
|
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
def prepare_optimizer_params(
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
self,
|
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
||||||
text_encoder_lr,
|
# if (
|
||||||
unet_lr,
|
# self.loraplus_lr_ratio is not None
|
||||||
default_lr,
|
# or self.loraplus_text_encoder_lr_ratio is not None
|
||||||
text_encoder_loraplus_ratio=None,
|
# or self.loraplus_unet_lr_ratio is not None
|
||||||
unet_loraplus_ratio=None,
|
# ):
|
||||||
loraplus_ratio=None
|
# assert (
|
||||||
):
|
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
||||||
|
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
||||||
|
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
|
|
||||||
all_params = []
|
all_params = []
|
||||||
|
lr_descriptions = []
|
||||||
|
|
||||||
def assemble_params(loras, lr, ratio):
|
def assemble_params(loras, lr, ratio):
|
||||||
param_groups = {"lora": {}, "plus": {}}
|
param_groups = {"lora": {}, "plus": {}}
|
||||||
@@ -1056,6 +1073,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||||
|
|
||||||
params = []
|
params = []
|
||||||
|
descriptions = []
|
||||||
for key in param_groups.keys():
|
for key in param_groups.keys():
|
||||||
param_data = {"params": param_groups[key].values()}
|
param_data = {"params": param_groups[key].values()}
|
||||||
|
|
||||||
@@ -1069,20 +1087,22 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
param_data["lr"] = lr
|
param_data["lr"] = lr
|
||||||
|
|
||||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||||
print("NO LR skipping!")
|
logger.info("NO LR skipping!")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
params.append(param_data)
|
params.append(param_data)
|
||||||
|
descriptions.append("plus" if key == "plus" else "")
|
||||||
|
|
||||||
return params
|
return params, descriptions
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params = assemble_params(
|
params, descriptions = assemble_params(
|
||||||
self.text_encoder_loras,
|
self.text_encoder_loras,
|
||||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||||
text_encoder_loraplus_ratio or loraplus_ratio
|
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||||
)
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
if self.block_lr:
|
if self.block_lr:
|
||||||
@@ -1096,22 +1116,24 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
# blockごとにパラメータを設定する
|
# blockごとにパラメータを設定する
|
||||||
for idx, block_loras in block_idx_to_lora.items():
|
for idx, block_loras in block_idx_to_lora.items():
|
||||||
params = assemble_params(
|
params, descriptions = assemble_params(
|
||||||
block_loras,
|
block_loras,
|
||||||
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
||||||
unet_loraplus_ratio or loraplus_ratio
|
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||||
)
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
params = assemble_params(
|
params, descriptions = assemble_params(
|
||||||
self.unet_loras,
|
self.unet_loras,
|
||||||
unet_lr if unet_lr is not None else default_lr,
|
unet_lr if unet_lr is not None else default_lr,
|
||||||
unet_loraplus_ratio or loraplus_ratio
|
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||||
)
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
||||||
|
|
||||||
return all_params
|
return all_params, lr_descriptions
|
||||||
|
|
||||||
def enable_gradient_checkpointing(self):
|
def enable_gradient_checkpointing(self):
|
||||||
# not supported
|
# not supported
|
||||||
|
|||||||
114
train_network.py
114
train_network.py
@@ -53,7 +53,15 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
def generate_step_logs(
|
def generate_step_logs(
|
||||||
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
|
self,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
current_loss,
|
||||||
|
avr_loss,
|
||||||
|
lr_scheduler,
|
||||||
|
lr_descriptions,
|
||||||
|
keys_scaled=None,
|
||||||
|
mean_norm=None,
|
||||||
|
maximum_norm=None,
|
||||||
):
|
):
|
||||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||||
|
|
||||||
@@ -63,68 +71,25 @@ class NetworkTrainer:
|
|||||||
logs["max_norm/max_key_norm"] = maximum_norm
|
logs["max_norm/max_key_norm"] = maximum_norm
|
||||||
|
|
||||||
lrs = lr_scheduler.get_last_lr()
|
lrs = lr_scheduler.get_last_lr()
|
||||||
|
for i, lr in enumerate(lrs):
|
||||||
if len(lrs) > 4:
|
if lr_descriptions is not None:
|
||||||
idx = 0
|
lr_desc = lr_descriptions[i]
|
||||||
if not args.network_train_unet_only:
|
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
|
||||||
idx = 1
|
|
||||||
|
|
||||||
for i in range(idx, len(lrs)):
|
|
||||||
lora_plus = ""
|
|
||||||
group_id = i
|
|
||||||
|
|
||||||
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
|
||||||
lora_plus = '_lora+' if i % 2 == 1 else ''
|
|
||||||
group_id = int((i / 2) + (i % 2 + 0.5))
|
|
||||||
|
|
||||||
logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
|
|
||||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
|
||||||
logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
|
|
||||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if args.network_train_text_encoder_only:
|
|
||||||
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
|
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
|
||||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
|
||||||
else:
|
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
|
||||||
|
|
||||||
elif args.network_train_unet_only:
|
|
||||||
if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
|
||||||
logs["lr/unet"] = float(lrs[0])
|
|
||||||
logs["lr/unet_lora+"] = float(lrs[1])
|
|
||||||
else:
|
|
||||||
logs["lr/unet"] = float(lrs[0])
|
|
||||||
else:
|
else:
|
||||||
if len(lrs) == 2:
|
idx = i - (0 if args.network_train_unet_only else -1)
|
||||||
if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
|
if idx == -1:
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
lr_desc = "textencoder"
|
||||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
|
||||||
elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
|
|
||||||
logs["lr/unet"] = float(lrs[0])
|
|
||||||
logs["lr/unet_lora+"] = float(lrs[1])
|
|
||||||
elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
|
|
||||||
logs["lr/all"] = float(lrs[0])
|
|
||||||
logs["lr/all_lora+"] = float(lrs[1])
|
|
||||||
else:
|
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
|
||||||
logs["lr/unet"] = float(lrs[-1])
|
|
||||||
elif len(lrs) == 4:
|
|
||||||
logs["lr/textencoder"] = float(lrs[0])
|
|
||||||
logs["lr/textencoder_lora+"] = float(lrs[1])
|
|
||||||
logs["lr/unet"] = float(lrs[2])
|
|
||||||
logs["lr/unet_lora+"] = float(lrs[3])
|
|
||||||
else:
|
else:
|
||||||
logs["lr/all"] = float(lrs[0])
|
if len(lrs) > 2:
|
||||||
|
lr_desc = f"group{idx}"
|
||||||
|
else:
|
||||||
|
lr_desc = "unet"
|
||||||
|
|
||||||
if (
|
logs[f"lr/{lr_desc}"] = lr
|
||||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
|
||||||
): # tracking d*lr value of unet.
|
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||||
logs["lr/d*lr"] = (
|
# tracking d*lr value
|
||||||
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
logs[f"lr/d*lr/{lr_desc}"] = (
|
||||||
|
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return logs
|
return logs
|
||||||
@@ -358,6 +323,7 @@ class NetworkTrainer:
|
|||||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||||
|
|
||||||
if args.network_weights is not None:
|
if args.network_weights is not None:
|
||||||
|
# FIXME consider alpha of weights
|
||||||
info = network.load_weights(args.network_weights)
|
info = network.load_weights(args.network_weights)
|
||||||
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
||||||
|
|
||||||
@@ -373,20 +339,23 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# 後方互換性を確保するよ
|
# 後方互換性を確保するよ
|
||||||
try:
|
try:
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
|
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||||
|
if type(results) is tuple:
|
||||||
|
trainable_params = results[0]
|
||||||
|
lr_descriptions = results[1]
|
||||||
|
else:
|
||||||
|
trainable_params = results
|
||||||
|
lr_descriptions = None
|
||||||
except TypeError:
|
except TypeError:
|
||||||
accelerator.print(
|
# accelerator.print(
|
||||||
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||||
)
|
# )
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||||
|
lr_descriptions = None
|
||||||
|
print(lr_descriptions)
|
||||||
|
|
||||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
|
|
||||||
assert (
|
|
||||||
(optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
|
|
||||||
), "LoRA+ and Prodigy/DAdaptation is not supported"
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
@@ -992,7 +961,9 @@ class NetworkTrainer:
|
|||||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
logs = self.generate_step_logs(
|
||||||
|
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
|
||||||
|
)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
)
|
)
|
||||||
|
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||||
|
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||||
|
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user