diff --git a/library/train_util.py b/library/train_util.py index 8c6e3437..a8ee260c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1883,7 +1883,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, Lion8bit,SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", ) # backward compatibility @@ -2448,7 +2448,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, Lion8bit, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -2525,6 +2525,15 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Lion8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.Lion8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "SGDNesterov".lower(): print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") diff --git a/train_README-ja.md b/train_README-ja.md index fd66458a..a155febd 100644 --- a/train_README-ja.md +++ b/train_README-ja.md @@ -563,6 +563,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - 過去のバージョンの--use_8bit_adam指定時と同じ - Lion : https://github.com/lucidrains/lion-pytorch - 過去のバージョンの--use_lion_optimizer指定時と同じ + - Lion8bit : 引数は同上 - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True - SGDNesterov8bit : 引数は同上 - DAdaptation : https://github.com/facebookresearch/dadaptation