diff --git a/library/train_util.py b/library/train_util.py
index cc9ac455..7e94158c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2657,7 +2657,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
     )

     # backward compatibility
@@ -3373,7 +3373,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):


 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
@@ -3477,6 +3477,20 @@ def get_optimizer(args, trainable_params):

         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type == "PagedAdamW".lower():
+        print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+        try:
+            optimizer_class = bnb.optim.PagedAdamW
+        except AttributeError:
+            raise AttributeError(
+                "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     elif optimizer_type == "PagedAdamW32bit".lower():
         print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
         try:
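
For reference, a minimal sketch of what the new PagedAdamW branch resolves to at runtime, assuming bitsandbytes 0.39.0 or later is installed; the tiny parameter list is illustrative and stands in for the trainable_params the trainer passes in:

    import bitsandbytes as bnb
    import torch

    # Stand-in for trainable_params; the trainer builds this from the model.
    params = [torch.nn.Parameter(torch.zeros(8, 8, device="cuda"))]

    # bnb.optim.PagedAdamW (added in bitsandbytes 0.39.0) keeps full-precision
    # AdamW state in paged memory, which CUDA can evict to host RAM under GPU
    # memory pressure instead of raising an out-of-memory error.
    optimizer = bnb.optim.PagedAdamW(params, lr=1e-4)

From the command line the new option is selected with --optimizer_type PagedAdamW. The comparison against "PagedAdamW".lower() follows the pattern of the surrounding branches, which compare against a lowercased optimizer_type, so the flag value is effectively case-insensitive.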