diff --git a/library/train_util.py b/library/train_util.py
index 048ed2ce..15c23f3c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2920,9 +2920,6 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
-    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
-    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
-    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
diff --git a/networks/lora.py b/networks/lora.py
index edbbdc0d..b67c59bd 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -490,6 +490,14 @@ def create_network(
         varbose=True,
     )
 
+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
@@ -1033,18 +1041,27 @@ class LoRANetwork(torch.nn.Module):
 
         return lr_weight
 
+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
+        # if (
+        #     self.loraplus_lr_ratio is not None
+        #     or self.loraplus_text_encoder_lr_ratio is not None
+        #     or self.loraplus_unet_lr_ratio is not None
+        # ):
+        #     assert (
+        #         optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
+        #     ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
+
         self.requires_grad_(True)
+
         all_params = []
+        lr_descriptions = []
 
         def assemble_params(loras, lr, ratio):
             param_groups = {"lora": {}, "plus": {}}
@@ -1056,6 +1073,7 @@ class LoRANetwork(torch.nn.Module):
                         param_groups["lora"][f"{lora.lora_name}.{name}"] = param
 
             params = []
+            descriptions = []
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
 
@@ -1069,20 +1087,22 @@ class LoRANetwork(torch.nn.Module):
                         param_data["lr"] = lr
 
                 if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
-                    print("NO LR skipping!")
+                    logger.info("NO LR skipping!")
                     continue
 
                 params.append(param_data)
+                descriptions.append("plus" if key == "plus" else "")
 
-            return params
+            return params, descriptions
 
         if self.text_encoder_loras:
-            params = assemble_params(
+            params, descriptions = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
             )
             all_params.extend(params)
+            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
 
         if self.unet_loras:
             if self.block_lr:
@@ -1096,22 +1116,24 @@ class LoRANetwork(torch.nn.Module):
 
                 # blockごとにパラメータを設定する
                 for idx, block_loras in block_idx_to_lora.items():
-                    params = assemble_params(
+                    params, descriptions = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_loraplus_ratio or loraplus_ratio
+                        self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                     )
                     all_params.extend(params)
+                    lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
 
             else:
-                params = assemble_params(
+                params, descriptions = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
                 )
                 all_params.extend(params)
+                lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
 
-        return all_params
+        return all_params, lr_descriptions
 
     def enable_gradient_checkpointing(self):
         # not supported
diff --git a/train_network.py b/train_network.py
index 9670490a..c43241e8 100644
--- a/train_network.py
+++ b/train_network.py
@@ -53,7 +53,15 @@ class NetworkTrainer:
 
     # TODO 他のスクリプトと共通化する
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self,
+        args: argparse.Namespace,
+        current_loss,
+        avr_loss,
+        lr_scheduler,
+        lr_descriptions,
+        keys_scaled=None,
+        mean_norm=None,
+        maximum_norm=None,
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -63,68 +71,25 @@ class NetworkTrainer:
             logs["max_norm/max_key_norm"] = maximum_norm
 
         lrs = lr_scheduler.get_last_lr()
-
-        if len(lrs) > 4:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                lora_plus = ""
-                group_id = i
-
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    lora_plus = '_lora+' if i % 2 == 1 else ''
-                    group_id = int((i / 2) + (i % 2 + 0.5))
-
-                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
-
-        else:
-            if args.network_train_text_encoder_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/textencoder"] = float(lrs[0])
-
-            elif args.network_train_unet_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    logs["lr/unet"] = float(lrs[0])
-                    logs["lr/unet_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/unet"] = float(lrs[0])
+        for i, lr in enumerate(lrs):
+            if lr_descriptions is not None:
+                lr_desc = lr_descriptions[i]
             else:
-                if len(lrs) == 2:
-                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/textencoder_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
-                        logs["lr/unet"] = float(lrs[0])
-                        logs["lr/unet_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
-                        logs["lr/all"] = float(lrs[0])
-                        logs["lr/all_lora+"] = float(lrs[1])
-                    else:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/unet"] = float(lrs[-1])
-                elif len(lrs) == 4:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                    logs["lr/unet"] = float(lrs[2])
-                    logs["lr/unet_lora+"] = float(lrs[3])
+                idx = i - (0 if args.network_train_unet_only else -1)
+                if idx == -1:
+                    lr_desc = "textencoder"
                 else:
-                    logs["lr/all"] = float(lrs[0])
+                    if len(lrs) > 2:
+                        lr_desc = f"group{idx}"
+                    else:
+                        lr_desc = "unet"
 
-            if (
-                args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-            ):  # tracking d*lr value of unet.
-                logs["lr/d*lr"] = (
-                    lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+            logs[f"lr/{lr_desc}"] = lr
+
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                # tracking d*lr value
+                logs[f"lr/d*lr/{lr_desc}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
 
         return logs
@@ -358,6 +323,7 @@ class NetworkTrainer:
             network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
         if args.network_weights is not None:
+            # FIXME consider alpha of weights
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
 
@@ -373,20 +339,23 @@ class NetworkTrainer:
 
        # 後方互換性を確保するよ
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
+            results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            if type(results) is tuple:
+                trainable_params = results[0]
+                lr_descriptions = results[1]
+            else:
+                trainable_params = results
+                lr_descriptions = None
         except TypeError:
-            accelerator.print(
-                "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
-            )
+            # accelerator.print(
+            #     "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
+            # )
             trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+            lr_descriptions = None
+        print(lr_descriptions)
 
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
-        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-            assert (
-                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
-            ), "LoRA+ and Prodigy/DAdaptation is not supported"
-
         # dataloaderを準備する
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
@@ -992,7 +961,9 @@ class NetworkTrainer:
                     progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
+                    )
                     accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
 
     return parser