Add error message if no Lion8bit

This commit is contained in:
ykume
2023-05-03 10:35:47 +09:00
parent 335b2f960e
commit a7485e4d9e

View File

@@ -2525,14 +2525,21 @@ def get_optimizer(args, trainable_params):
print(f"use Lion optimizer | {optimizer_kwargs}")
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "Lion8bit".lower():
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
optimizer_class = bnb.optim.Lion8bit
try:
optimizer_class = bnb.optim.Lion8bit
except AttributeError:
raise AttributeError(
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov".lower():