Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
expand optimizer options and refactor
Refactor code to make it easier to add new optimizers and to support alternate optimizer parameters:
- move redundant optimizer-initialization code into train_util
- add the SGD Nesterov optimizers as options (since they are already available)
- add new parameters which may be helpful for tuning existing and new optimizers
@@ -1366,6 +1366,26 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
                       help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
 
 
+def add_optimizer_arguments(parser: argparse.ArgumentParser):
+  parser.add_argument("--optimizer_type", type=str, default="AdamW",
+                      help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
+
+  parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+  parser.add_argument("--optimizer_momentum", type=float, default=0.9,
+                      help="Momentum value for optimizers")
+  parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
+                      help="Weight decay for optimizers")
+  parser.add_argument("--optimizer_beta1", type=float, default=0.9,
+                      help="beta1 parameter for Adam optimizers")
+  parser.add_argument("--optimizer_beta2", type=float, default=0.999,
+                      help="beta2 parameter for Adam optimizers")
+
+  parser.add_argument("--lr_scheduler", type=str, default="constant",
+                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+  parser.add_argument("--lr_warmup_steps", type=int, default=0,
+                      help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+
+
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
   parser.add_argument("--output_dir", type=str, default=None,
                       help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
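
For context, a training script is expected to register the new argument group next to the existing hooks. A minimal sketch, not part of the diff, assuming all three helpers live in the same train_util module:

    import argparse

    parser = argparse.ArgumentParser()
    add_sd_models_arguments(parser)                          # existing hook
    add_optimizer_arguments(parser)                          # new in this commit
    add_training_arguments(parser, support_dreambooth=True)  # existing hook
    args = parser.parse_args()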
@@ -1387,10 +1407,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
   parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                       help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
-  parser.add_argument("--use_8bit_adam", action="store_true",
-                      help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
-  parser.add_argument("--use_lion_optimizer", action="store_true",
-                      help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
   parser.add_argument("--mem_eff_attn", action="store_true",
                       help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
   parser.add_argument("--xformers", action="store_true",
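
The two boolean switches removed here are subsumed by the new --optimizer_type selector added above: --use_8bit_adam corresponds to --optimizer_type AdamW8bit, --use_lion_optimizer to --optimizer_type Lion, and passing neither keeps the AdamW default. An illustrative before/after (script name hypothetical):

    # before: python train.py --use_8bit_adam ...
    # after:  python train.py --optimizer_type AdamW8bit ...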
@@ -1398,7 +1414,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--vae", type=str, default=None,
                       help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
 
-  parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
   parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
   parser.add_argument("--max_train_epochs", type=int, default=None,
                       help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
@@ -1419,10 +1434,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--logging_dir", type=str, default=None,
                       help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
   parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
-  parser.add_argument("--lr_scheduler", type=str, default="constant",
-                      help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
-  parser.add_argument("--lr_warmup_steps", type=int, default=0,
-                      help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
   parser.add_argument("--noise_offset", type=float, default=None,
                       help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
   parser.add_argument("--lowram", action="store_true",
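
The scheduler flags removed here reappear unchanged in add_optimizer_arguments above. Downstream they would typically feed a scheduler factory; a sketch of that wiring, assuming diffusers' get_scheduler (this wiring is not part of the diff):

    from diffusers.optimization import get_scheduler

    # optimizer as returned by get_optimizer(args, ...) below
    lr_scheduler = get_scheduler(
        args.lr_scheduler,                       # e.g. "constant", "cosine"
        optimizer,
        num_warmup_steps=args.lr_warmup_steps,   # default 0 disables warmup
        num_training_steps=args.max_train_steps,
    )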
@@ -1503,6 +1514,58 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
 
 # region utils
 
+# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit"
+
+def get_optimizer(args, trainable_params):
+  # Prepare optimizer / 学習に必要なクラスを準備する
+  optimizer_type = args.optimizer_type.lower()
+
+  betas = (args.optimizer_beta1, args.optimizer_beta2)
+  weight_decay = args.optimizer_weightdecay
+  momentum = args.optimizer_momentum
+  lr = args.learning_rate
+
+  if optimizer_type == "AdamW8bit".lower():
+    try:
+      import bitsandbytes as bnb
+    except ImportError:
+      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+    print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    optimizer_class = bnb.optim.AdamW8bit
+    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
+  elif optimizer_type == "SGDNesterov8bit".lower():
+    try:
+      import bitsandbytes as bnb
+    except ImportError:
+      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+    print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    optimizer_class = bnb.optim.SGD8bit
+    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+
+  elif optimizer_type == "Lion".lower():
+    try:
+      import lion_pytorch
+    except ImportError:
+      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+    print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    optimizer_class = lion_pytorch.Lion
+    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
+  elif optimizer_type == "SGDNesterov".lower():
+    print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
+    optimizer_class = torch.optim.SGD
+    optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
+
+  else:
+    print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
+    optimizer_class = torch.optim.AdamW
+    optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
+
+  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+
+  return optimizer_name, optimizer
+
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
   # backward compatibility
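
A minimal end-to-end sketch of the new helper (the toy model is an illustration, not part of the commit; note that get_optimizer lower-cases --optimizer_type, so the match is case-insensitive):

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    add_optimizer_arguments(parser)
    args = parser.parse_args(["--optimizer_type", "SGDNesterov", "--learning_rate", "1e-5"])

    model = torch.nn.Linear(4, 4)  # stand-in for the real network
    optimizer_name, optimizer = get_optimizer(args, model.parameters())
    print(optimizer_name)  # -> torch.optim.sgd.SGD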