diff --git a/library/train_util.py b/library/train_util.py index a8ee260c..6c064738 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2525,14 +2525,21 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - + elif optimizer_type == "Lion8bit".lower(): try: import bitsandbytes as bnb except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.Lion8bit + try: + optimizer_class = bnb.optim.Lion8bit + except AttributeError: + raise AttributeError( + "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower():