From 0c9c90a87ed9761000d74cce09af52565fa5fc45 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Wed, 26 Apr 2023 16:15:33 +0800
Subject: [PATCH 01/10] Update train_util.py to add DAdaptAdan and DAdaptSGD

---
 library/train_util.py | 58 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 8c6e3437..f8277265 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2563,7 +2563,65 @@ def get_optimizer(args, trainable_params):
 
         optimizer_class = dadaptation.DAdaptAdam
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "DAdaptAdan".lower():
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+        print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+
+        actual_lr = lr
+        lr_count = 1
+        if type(trainable_params) == list and type(trainable_params[0]) == dict:
+            lrs = set()
+            actual_lr = trainable_params[0].get("lr", actual_lr)
+            for group in trainable_params:
+                lrs.add(group.get("lr", actual_lr))
+            lr_count = len(lrs)
+
+        if actual_lr <= 0.1:
+            print(
+                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
+            )
+            print("recommend option: lr=1.0 / 推奨は1.0です")
+        if lr_count > 1:
+            print(
+                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+            )
+
+        optimizer_class = dadaptation.DAdaptAdan
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+    elif optimizer_type == "DAdaptSGD".lower():
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+        print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
+
+        actual_lr = lr
+        lr_count = 1
+        if type(trainable_params) == list and type(trainable_params[0]) == dict:
+            lrs = set()
+            actual_lr = trainable_params[0].get("lr", actual_lr)
+            for group in trainable_params:
+                lrs.add(group.get("lr", actual_lr))
+            lr_count = len(lrs)
+
+        if actual_lr <= 0.1:
+            print(
+                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
+            )
+            print("recommend option: lr=1.0 / 推奨は1.0です")
+        if lr_count > 1:
+            print(
+                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+            )
+
+        optimizer_class = dadaptation.DAdaptSGD
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     elif optimizer_type == "Adafactor".lower():
         # 引数を確認して適宜補正する
         if "relative_step" not in optimizer_kwargs:
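D-Adaptation estimates the step size d on the fly; the lr passed to these classes scales that estimate rather than setting an absolute rate, which is why the checks above warn when lr <= 0.1 and recommend lr=1.0. A minimal sketch of driving one of the new classes directly, assuming only that the dadaptation package is installed (the toy model and data are illustrative):

    import torch
    import dadaptation

    model = torch.nn.Linear(4, 1)
    # lr=1.0 leaves the internally estimated d unscaled; values well below
    # 1.0 suppress the adaptation, which the lr <= 0.1 warning guards against.
    optimizer = dadaptation.DAdaptSGD(model.parameters(), lr=1.0)

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(10):
        optimizer.zero_grad()
        torch.nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()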
From 0db2eddacebec595163cf1da84924ca2e0d79b71 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Wed, 26 Apr 2023 16:51:12 +0800
Subject: [PATCH 02/10] Update train_util.py for DAdaptAdam

---
 library/train_util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index f8277265..c80c4cca 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2535,7 +2535,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.SGD
         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
-    elif optimizer_type == "DAdaptation".lower():
+    elif optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
         try:
             import dadaptation
         except ImportError:

From b357b6dbbc17b8faffedb18197055845cc3a6180 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Wed, 26 Apr 2023 16:52:27 +0800
Subject: [PATCH 03/10] Update train_network.py for DAdapt

---
 train_network.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_network.py b/train_network.py
index 5c4d5ad1..743ac75b 100644
--- a/train_network.py
+++ b/train_network.py
@@ -44,7 +44,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         logs["lr/textencoder"] = float(lrs[0])
         logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
 
-        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
+        if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value of unet.
             logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     else:
         idx = 0
@@ -54,7 +54,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         for i in range(idx, len(lrs)):
             logs[f"lr/group{i}"] = float(lrs[i])
 
-            if args.optimizer_type.lower() == "DAdaptation".lower():
+            if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():
                 logs[f"lr/d*lr/group{i}"] = (
                     lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
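The lr/d*lr value tracked here is the effective step size: the D-Adaptation optimizers store their current estimate as "d" in each param group, which is exactly where the patch reads it. A standalone sketch of the same read, assuming only the dadaptation package and a toy model:

    import torch
    import dadaptation

    model = torch.nn.Linear(4, 1)
    optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0)

    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

    group = optimizer.param_groups[0]
    print("lr/d*lr:", group["d"] * group["lr"])  # the value generate_step_logs records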
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -54,7 +54,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower() == "DAdaptation".lower(): + if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) From f675b097b0d36cdf9bce3a81e7481ba838215c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 27 Apr 2023 00:08:32 +0800 Subject: [PATCH 04/10] Update train_README-ja.md for DAdapt --- train_README-ja.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_README-ja.md b/train_README-ja.md index fd66458a..687bdc3c 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -565,7 +565,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - 過去のバージョンの--use_lion_optimizer指定時と同じ - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - SGDNesterov8bit : 引数は同上 - - DAdaptation : https://github.com/facebookresearch/dadaptation + - DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation + - DAdaptAdaGrad : 引数は同上 + - DAdaptAdan : 引数は同上 + - DAdaptSGD : 引数は同上 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任意のオプティマイザ From eccc07537bf89e6fbd8f46c2855f3e5cd42ff952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 27 Apr 2023 00:17:32 +0800 Subject: [PATCH 05/10] Update train_util.py for DAdapt --- library/train_util.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c80c4cca..e7498f07 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1883,7 +1883,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -2448,7 +2448,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -2563,6 +2563,35 @@ def get_optimizer(args, trainable_params): optimizer_class = dadaptation.DAdaptAdam optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "DAdaptAdaGrad".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + print(f"use 
From 725cce77e4fc61cfd268fc247fb93f5b29d051a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Thu, 27 Apr 2023 00:23:32 +0800
Subject: [PATCH 06/10] Update train_network.py for DAdaptAdaGrad

---
 train_network.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_network.py b/train_network.py
index 743ac75b..a864cdd7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -44,7 +44,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         logs["lr/textencoder"] = float(lrs[0])
         logs["lr/unet"] = float(lrs[-1])  # may be same to textencoder
 
-        if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value of unet.
+        if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value of unet.
             logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     else:
         idx = 0
@@ -54,7 +54,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         for i in range(idx, len(lrs)):
             logs[f"lr/group{i}"] = float(lrs[i])
 
-            if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():
+            if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():
                 logs[f"lr/d*lr/group{i}"] = (
                     lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                 )
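With DAdaptAdaGrad included, the guard above compares optimizer_type against five names inline, and the remaining patches repeat the same chain in four more scripts. A hedged sketch of one shared predicate; the constant and the function name are assumptions of this note, not code in the series:

    DADAPTATION_TYPES = ("dadaptation", "dadaptadam", "dadaptadagrad", "dadaptadan", "dadaptsgd")

    def is_dadaptation(optimizer_type: str) -> bool:
        # Equivalent to the chained == "...".lower() comparisons above.
        return optimizer_type.lower() in DADAPTATION_TYPES

Call sites would then read if is_dadaptation(args.optimizer_type):, making the next optimizer variant a one-line change.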
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -54,7 +54,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): + if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) From 871644b64271526f0d6ee43770b4103b3f2a6a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 27 Apr 2023 00:24:21 +0800 Subject: [PATCH 07/10] Update train_db.py for DAdapt --- train_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_db.py b/train_db.py index 178d5cb4..cb566fa6 100644 --- a/train_db.py +++ b/train_db.py @@ -361,7 +361,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) From 208ad06be063d5d5283685f0e1be09cea395e761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 27 Apr 2023 00:25:54 +0800 Subject: [PATCH 08/10] Update fine_tune.py for DAdapt --- fine_tune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index b6a8d1d7..4227dd04 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -376,7 +376,7 @@ def train(args): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) From 66aeef40fea563034ee179608726c1d566ce7a56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 27 Apr 2023 00:26:40 +0800 Subject: [PATCH 09/10] Update train_textual_inversion.py for DAdapt --- train_textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 
From 66aeef40fea563034ee179608726c1d566ce7a56 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Thu, 27 Apr 2023 00:26:40 +0800
Subject: [PATCH 09/10] Update train_textual_inversion.py for DAdapt

---
 train_textual_inversion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index fb6b6053..72a40fc0 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -460,7 +460,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )

From 6d0ecb74493c809c82a196147d42bb7c0209c014 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?=
Date: Thu, 27 Apr 2023 00:27:07 +0800
Subject: [PATCH 10/10] Update train_textual_inversion_XTI.py for DAdapt

---
 train_textual_inversion_XTI.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 69ec3eb1..576a30a6 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -499,7 +499,7 @@ def train(args):
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+                if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower():  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )