From 335b2f960e24be4a4ae4a258cf210318502f9de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 3 May 2023 09:22:40 +0800 Subject: [PATCH] Support for Lion8bit (#447) * ADD libbitsandbytes.dll for 0.38.1 * Delete libbitsandbytes_cuda116.dll * Delete cextension.py * add main.py * Update requirements.txt for bitsandbytes 0.38.1 * Update README.md for bitsandbytes-windows * Update README-ja.md for bitsandbytes 0.38.1 * Update main.py for return cuda118 * Update train_util.py for lion8bit * Update train_README-ja.md for lion8bit * Update train_util.py for add DAdaptAdan and DAdaptSGD * Update train_util.py for DAdaptadam * Update train_network.py for dadapt * Update train_README-ja.md for DAdapt * Update train_util.py for DAdapt * Update train_network.py for DAdaptAdaGrad * Update train_db.py for DAdapt * Update fine_tune.py for DAdapt * Update train_textual_inversion.py for DAdapt * Update train_textual_inversion_XTI.py for DAdapt * Revert "Merge branch 'qinglong' into main" This reverts commit b65c023083d6d1e8a30eb42eddd603d1aac97650, reversing changes made to f6fda20caf5e773d56bcfb5c4575c650bb85362b. * Revert "Update requirements.txt for bitsandbytes 0.38.1" This reverts commit 83abc60dfaddb26845f54228425b98dd67997528. * Revert "Delete cextension.py" This reverts commit 3ba4dfe046874393f2a022a4cbef3628ada35391. * Revert "Update README.md for bitsandbytes-windows" This reverts commit 4642c52086b5e9791233007e2fdfd97f832cd897. * Revert "Update README-ja.md for bitsandbytes 0.38.1" This reverts commit fa6d7485ac067ebc49e6f381afdb8dd2f12caa8f. * Revert "ADD libbitsandbytes.dll for 0.38.1" This reverts commit bee1e6f731d2428dacb34b61997f06143c69c278. * Revert "Delete libbitsandbytes_cuda116.dll" This reverts commit 891c7e92623dab92f3767663982627cca6a26724. * reverse main.py * Reverse main.py --- library/train_util.py | 13 +++++++++++-- train_README-ja.md | 1 + 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8c6e3437..a8ee260c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1883,7 +1883,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, Lion8bit,SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", ) # backward compatibility @@ -2448,7 +2448,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, Lion8bit, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -2525,6 +2525,15 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Lion8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.Lion8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") diff --git a/train_README-ja.md b/train_README-ja.md index fd66458a..a155febd 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -563,6 +563,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - 過去のバージョンの--use_8bit_adam指定時と同じ - Lion : https://github.com/lucidrains/lion-pytorch - 過去のバージョンの--use_lion_optimizer指定時と同じ + - Lion8bit : 引数は同上 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - SGDNesterov8bit : 引数は同上 - DAdaptation : https://github.com/facebookresearch/dadaptation