From 5845de7d7c6c9d8dd6123e7b29f39302e8a8140a Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 15 Jun 2023 21:47:37 +0900
Subject: [PATCH] common lr checking for dadaptation and prodigy

---
 library/train_util.py | 114 +++++++++++++++++------------------------
 1 file changed, 47 insertions(+), 67 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 5b5d99ac..acfb503b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2752,15 +2752,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.SGD
         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
 
-    elif optimizer_type.startswith("DAdapt".lower()):
-        # DAdaptation family
-        # check dadaptation is installed
-        try:
-            import dadaptation
-            import dadaptation.experimental as experimental
-        except ImportError:
-            raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
-
+    elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
         # check lr and lr_count, and print warning
         actual_lr = lr
         lr_count = 1
@@ -2773,72 +2765,60 @@ def get_optimizer(args, trainable_params):
 
         if actual_lr <= 0.1:
             print(
-                f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
+                f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
             )
             print("recommend option: lr=1.0 / 推奨は1.0です")
         if lr_count > 1:
             print(
-                f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
+                f"when multiple learning rates are specified with D-Adaptation or Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
             )
 
-        # set optimizer
-        if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
-            optimizer_class = experimental.DAdaptAdamPreprint
-            print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdaGrad".lower():
-            optimizer_class = dadaptation.DAdaptAdaGrad
-            print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdam".lower():
-            optimizer_class = dadaptation.DAdaptAdam
-            print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdan".lower():
-            optimizer_class = dadaptation.DAdaptAdan
-            print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptAdanIP".lower():
-            optimizer_class = experimental.DAdaptAdanIP
-            print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptLion".lower():
-            optimizer_class = dadaptation.DAdaptLion
-            print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
-        elif optimizer_type == "DAdaptSGD".lower():
-            optimizer_class = dadaptation.DAdaptSGD
-            print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
+        if optimizer_type.startswith("DAdapt".lower()):
+            # DAdaptation family
+            # check dadaptation is installed
+            try:
+                import dadaptation
+                import dadaptation.experimental as experimental
+            except ImportError:
+                raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+
+            # set optimizer
+            if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
+                optimizer_class = experimental.DAdaptAdamPreprint
+                print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdaGrad".lower():
+                optimizer_class = dadaptation.DAdaptAdaGrad
+                print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdam".lower():
+                optimizer_class = dadaptation.DAdaptAdam
+                print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdan".lower():
+                optimizer_class = dadaptation.DAdaptAdan
+                print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptAdanIP".lower():
+                optimizer_class = experimental.DAdaptAdanIP
+                print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptLion".lower():
+                optimizer_class = dadaptation.DAdaptLion
+                print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
+            elif optimizer_type == "DAdaptSGD".lower():
+                optimizer_class = dadaptation.DAdaptSGD
+                print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
+            else:
+                raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
         else:
-            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+            # Prodigy
+            # check Prodigy is installed
+            try:
+                import prodigyopt
+            except ImportError:
+                raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
 
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "Prodigy".lower():
-        # Prodigy
-        # check Prodigy is installed
-        try:
-            import prodigyopt
-        except ImportError:
-            raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
-
-        # check lr and lr_count, and print warning
-        actual_lr = lr
-        lr_count = 1
-        if type(trainable_params) == list and type(trainable_params[0]) == dict:
-            lrs = set()
-            actual_lr = trainable_params[0].get("lr", actual_lr)
-            for group in trainable_params:
-                lrs.add(group.get("lr", actual_lr))
-            lr_count = len(lrs)
-
-        if actual_lr <= 0.1:
-            print(
-                f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
-            )
-            print("recommend option: lr=1.0 / 推奨は1.0です")
-        if lr_count > 1:
-            print(
-                f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
-            )
-
-        print(f"use Prodigy optimizer | {optimizer_kwargs}")
-        optimizer_class = prodigyopt.Prodigy
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+            print(f"use Prodigy optimizer | {optimizer_kwargs}")
+            optimizer_class = prodigyopt.Prodigy
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
     elif optimizer_type == "Adafactor".lower():
         # 引数を確認して適宜補正する
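
For reference, the sketch below exercises the consolidated lr check in
isolation. This is a minimal sketch, not part of the patch: trainable_params
here is a hypothetical stand-in for the param-group list the training scripts
pass in (real groups also carry a "params" entry with actual tensors).

    # Mirrors the shared check above: collect the distinct per-group lrs
    # and take the first group's lr as the effective one.
    trainable_params = [
        {"params": [], "lr": 1e-4},  # e.g. Text Encoder group (hypothetical)
        {"params": [], "lr": 1e-5},  # e.g. U-Net group (hypothetical)
    ]
    lr = 1e-4  # fallback lr, e.g. from the command line

    actual_lr = lr
    lr_count = 1
    if type(trainable_params) == list and type(trainable_params[0]) == dict:
        lrs = set()
        actual_lr = trainable_params[0].get("lr", actual_lr)
        for group in trainable_params:
            lrs.add(group.get("lr", actual_lr))
        lr_count = len(lrs)

    # actual_lr == 1e-4 (<= 0.1) and lr_count == 2, so both warnings fire:
    # the lr is far below the ~1.0 these optimizers expect, and only the
    # first group's lr would take effect.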
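
Both optimizer families estimate the step size on the fly, which is why the
check recommends lr around 1.0. A minimal usage sketch, assuming prodigyopt
is installed; the model is a placeholder, not one of the repository's
networks:

    import torch
    import prodigyopt

    model = torch.nn.Linear(4, 4)  # placeholder module for illustration
    # lr acts as a multiplier on Prodigy's internally estimated step size,
    # so it stays at the recommended 1.0 rather than a small Adam-style lr.
    optimizer = prodigyopt.Prodigy(model.parameters(), lr=1.0)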