From 5cdf4e34a1e16d56d436f81860501188161c9640 Mon Sep 17 00:00:00 2001
From: 青龍聖者@bdsqlsz
Date: Thu, 25 May 2023 20:52:36 +0800
Subject: [PATCH] support for dadaptation V3 (#530)

* Update train_util.py for DAdaptLion

* Update train_README-zh.md for dadaptlion

* Update train_README-ja.md for DAdaptLion

* add DAdapt V3

* Alignment

* Update train_util.py for experimental

* Update train_util.py V3

* Update train_README-zh.md

* Update train_README-ja.md

* Update train_util.py fix

* Update train_util.py

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
---
 docs/train_README-ja.md |  5 ++++-
 docs/train_README-zh.md | 10 ++++++++--
 library/train_util.py   | 20 +++++++++++++++-----
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md
index f27c5c65..b64b1808 100644
--- a/docs/train_README-ja.md
+++ b/docs/train_README-ja.md
@@ -615,9 +615,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
 - Lion8bit : 引数は同上
 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
 - SGDNesterov8bit : 引数は同上
-- DAdaptation(DAdaptAdam) : https://github.com/facebookresearch/dadaptation
+- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
+- DAdaptAdam : 引数は同上
 - DAdaptAdaGrad : 引数は同上
 - DAdaptAdan : 引数は同上
+- DAdaptAdanIP : 引数は同上
+- DAdaptLion : 引数は同上
 - DAdaptSGD : 引数は同上
 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
 - 任意のオプティマイザ

diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md
index dbd26606..678832d2 100644
--- a/docs/train_README-zh.md
+++ b/docs/train_README-zh.md
@@ -550,8 +550,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
 - Lion : https://github.com/lucidrains/lion-pytorch
   - 与过去版本中指定的 --use_lion_optimizer 相同
 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
-- SGDNesterov8bit : 引数同上
-- DAdaptation : https://github.com/facebookresearch/dadaptation
+- SGDNesterov8bit : 参数同上
+- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
+- DAdaptAdam : 参数同上
+- DAdaptAdaGrad : 参数同上
+- DAdaptAdan : 参数同上
+- DAdaptAdanIP : 参数同上
+- DAdaptLion : 参数同上
+- DAdaptSGD : 参数同上
 - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
 - 任何优化器

diff --git a/library/train_util.py b/library/train_util.py
index 41afc13b..b3968c43 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1940,7 +1940,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, AdaFactor",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
     )

     # backward compatibility

@@ -2545,7 +2545,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):


 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation, DAdaptation(DAdaptAdam), DAdaptAdaGrad, DAdaptAdan, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, Lion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
     optimizer_type = args.optimizer_type

     if args.use_8bit_adam:

@@ -2653,6 +2653,7 @@ def get_optimizer(args, trainable_params):
         # check dadaptation is installed
         try:
             import dadaptation
+            import dadaptation.experimental as experimental
         except ImportError:
             raise ImportError("No dadaptation / dadaptation がインストールされていないようです")

@@ -2677,15 +2678,24 @@
         )

         # set optimizer
-        if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
-            optimizer_class = dadaptation.DAdaptAdam
-            print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+        if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
+            optimizer_class = experimental.DAdaptAdamPreprint
+            print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
         elif optimizer_type == "DAdaptAdaGrad".lower():
             optimizer_class = dadaptation.DAdaptAdaGrad
             print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptAdam".lower():
+            optimizer_class = dadaptation.DAdaptAdam
+            print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
         elif optimizer_type == "DAdaptAdan".lower():
             optimizer_class = dadaptation.DAdaptAdan
             print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptAdanIP".lower():
+            optimizer_class = experimental.DAdaptAdanIP
+            print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "DAdaptLion".lower():
+            optimizer_class = dadaptation.DAdaptLion
+            print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
         elif optimizer_type == "DAdaptSGD".lower():
             optimizer_class = dadaptation.DAdaptSGD
             print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
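
For reviewers who want to exercise the new optimizer names without launching a full training run, here is a minimal sketch that mirrors the string-to-class dispatch the patched `get_optimizer()` performs. It assumes `dadaptation` v3.0+ is installed; the `DADAPT_CLASSES` dict, the toy `Linear` model, and the lr value are illustrative helpers, not code from this patch.

```python
# Minimal sketch of the dispatch added in get_optimizer() above.
# Assumes `pip install dadaptation>=3.0`. DADAPT_CLASSES and the toy
# model below are illustrative, not part of the patch itself.
import torch
import dadaptation
import dadaptation.experimental as experimental

# Name -> class mapping equivalent to the if/elif chain in train_util.py.
# "DAdaptation" remains a backward-compatible alias for the preprint variant.
DADAPT_CLASSES = {
    "dadaptation": experimental.DAdaptAdamPreprint,
    "dadaptadampreprint": experimental.DAdaptAdamPreprint,
    "dadaptadam": dadaptation.DAdaptAdam,
    "dadaptadagrad": dadaptation.DAdaptAdaGrad,
    "dadaptadan": dadaptation.DAdaptAdan,
    "dadaptadanip": experimental.DAdaptAdanIP,
    "dadaptlion": dadaptation.DAdaptLion,
    "dadaptsgd": dadaptation.DAdaptSGD,
}

model = torch.nn.Linear(16, 4)  # stand-in for the trainable parameters
optimizer_class = DADAPT_CLASSES["DAdaptLion".lower()]
# D-Adaptation estimates the step size itself, so the upstream README
# recommends lr=1.0 as the starting point.
optimizer = optimizer_class(model.parameters(), lr=1.0)
print(f"use {optimizer_class.__name__} | lr=1.0")
```

With the patch applied, the same selection is made from the training scripts' command line, e.g. `--optimizer_type DAdaptLion --learning_rate 1.0` (a learning rate of 1.0 is the usual starting point for the D-Adaptation variants, per the upstream dadaptation README).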