From a7485e4d9e27455b4dd95235f8c840d905a6623e Mon Sep 17 00:00:00 2001 From: ykume Date: Wed, 3 May 2023 10:35:47 +0900 Subject: [PATCH] Add error message if no Lion8bit --- library/train_util.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index a8ee260c..6c064738 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2525,14 +2525,21 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - + elif optimizer_type == "Lion8bit".lower(): try: import bitsandbytes as bnb except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.Lion8bit + try: + optimizer_class = bnb.optim.Lion8bit + except AttributeError: + raise AttributeError( + "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower():