expand optimizer options and refactor

Refactor code to make it easier to add new optimizers, and support alternate optimizer parameters

-move redundant code to train_util for initializing optimizers
- add SGD Nesterov optimizers as option (since they are already available)
- add new parameters which may be helpful for tuning existing and new optimizers
This commit is contained in:
mgz-dev
2023-02-19 17:45:09 -06:00
parent 08ae46b163
commit b29c5a750c
5 changed files with 80 additions and 96 deletions

View File

@@ -1366,6 +1366,26 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--optimizer_type", type=str, default="AdamW",
help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
help="Momentum value for optimizers")
parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
help="Weight decay for optimizers")
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
help="beta1 parameter for Adam optimizers")
parser.add_argument("--optimizer_beta2", type=float, default=0.999,
help="beta2 parameter for Adam optimizers")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
@@ -1387,10 +1407,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
@@ -1398,7 +1414,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--vae", type=str, default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None,
help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
@@ -1419,10 +1434,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--noise_offset", type=float, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)")
parser.add_argument("--lowram", action="store_true",
@@ -1503,6 +1514,58 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
# region utils
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit"
def get_optimizer(args, trainable_params):
# Prepare optimizer/学習に必要なクラスを準備する
optimizer_type = args.optimizer_type.lower()
betas = (args.optimizer_beta1, args.optimizer_beta2)
weight_decay = args.optimizer_weightdecay
momentum = args.optimizer_momentum
lr = args.learning_rate
if optimizer_type == "AdamW8bit".lower():
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = bnb.optim.AdamW8bit
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
elif optimizer_type == "SGDNesterov8bit".lower():
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
optimizer_class = bnb.optim.SGD8bit
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
elif optimizer_type == "Lion".lower():
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
elif optimizer_type == "SGDNesterov".lower():
print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
else:
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
return optimizer_name, optimizer
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility