mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
expand optimizer options and refactor
Refactor code to make it easier to add new optimizers, and support alternate optimizer parameters -move redundant code to train_util for initializing optimizers - add SGD Nesterov optimizers as option (since they are already available) - add new parameters which may be helpful for tuning existing and new optimizers
This commit is contained in:
23
fine_tune.py
23
fine_tune.py
@@ -149,27 +149,7 @@ def train(args):
|
|||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("prepare optimizer, data loader etc.")
|
||||||
|
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||||
# 8-bit Adamを使う
|
|
||||||
if args.use_8bit_adam:
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print("use 8-bit Adam optimizer")
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
elif args.use_lion_optimizer:
|
|
||||||
try:
|
|
||||||
import lion_pytorch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
|
||||||
print("use Lion optimizer")
|
|
||||||
optimizer_class = lion_pytorch.Lion
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
|
||||||
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
@@ -351,6 +331,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_dataset_arguments(parser, False, True, True)
|
train_util.add_dataset_arguments(parser, False, True, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
|
train_util.add_optimizer_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
||||||
|
|||||||
@@ -1366,6 +1366,26 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
||||||
|
|
||||||
|
|
||||||
|
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--optimizer_type", type=str, default="AdamW",
|
||||||
|
help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit")
|
||||||
|
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
|
parser.add_argument("--optimizer_momentum", type=float, default=0.9,
|
||||||
|
help="Momentum value for optimizers")
|
||||||
|
parser.add_argument("--optimizer_weightdecay", type=float, default=0.01,
|
||||||
|
help="Weight decay for optimizers")
|
||||||
|
parser.add_argument("--optimizer_beta1", type=float, default=0.9,
|
||||||
|
help="beta1 parameter for Adam optimizers")
|
||||||
|
parser.add_argument("--optimizer_beta2", type=float, default=0.999,
|
||||||
|
help="beta2 parameter for Adam optimizers")
|
||||||
|
|
||||||
|
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
||||||
|
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
||||||
|
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
||||||
|
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
||||||
|
|
||||||
|
|
||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
parser.add_argument("--output_dir", type=str, default=None,
|
parser.add_argument("--output_dir", type=str, default=None,
|
||||||
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
@@ -1387,10 +1407,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
||||||
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
||||||
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
||||||
parser.add_argument("--use_8bit_adam", action="store_true",
|
|
||||||
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
|
||||||
parser.add_argument("--use_lion_optimizer", action="store_true",
|
|
||||||
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
|
|
||||||
parser.add_argument("--mem_eff_attn", action="store_true",
|
parser.add_argument("--mem_eff_attn", action="store_true",
|
||||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
||||||
parser.add_argument("--xformers", action="store_true",
|
parser.add_argument("--xformers", action="store_true",
|
||||||
@@ -1398,7 +1414,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--vae", type=str, default=None,
|
parser.add_argument("--vae", type=str, default=None,
|
||||||
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
||||||
|
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--max_train_epochs", type=int, default=None,
|
parser.add_argument("--max_train_epochs", type=int, default=None,
|
||||||
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||||
@@ -1419,10 +1434,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--logging_dir", type=str, default=None,
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||||
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
||||||
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
|
||||||
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
|
||||||
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
|
||||||
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
|
||||||
parser.add_argument("--noise_offset", type=float, default=None,
|
parser.add_argument("--noise_offset", type=float, default=None,
|
||||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
||||||
parser.add_argument("--lowram", action="store_true",
|
parser.add_argument("--lowram", action="store_true",
|
||||||
@@ -1503,6 +1514,58 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|||||||
|
|
||||||
# region utils
|
# region utils
|
||||||
|
|
||||||
|
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit"
|
||||||
|
|
||||||
|
def get_optimizer(args, trainable_params):
|
||||||
|
# Prepare optimizer/学習に必要なクラスを準備する
|
||||||
|
optimizer_type = args.optimizer_type.lower()
|
||||||
|
|
||||||
|
betas = (args.optimizer_beta1, args.optimizer_beta2)
|
||||||
|
weight_decay = args.optimizer_weightdecay
|
||||||
|
momentum = args.optimizer_momentum
|
||||||
|
lr = args.learning_rate
|
||||||
|
|
||||||
|
if optimizer_type == "AdamW8bit".lower():
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||||
|
print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
|
||||||
|
optimizer_class = bnb.optim.AdamW8bit
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
elif optimizer_type == "SGDNesterov8bit".lower():
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||||
|
print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
|
||||||
|
optimizer_class = bnb.optim.SGD8bit
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
|
||||||
|
|
||||||
|
elif optimizer_type == "Lion".lower():
|
||||||
|
try:
|
||||||
|
import lion_pytorch
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
||||||
|
print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}")
|
||||||
|
optimizer_class = lion_pytorch.Lion
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
elif optimizer_type == "SGDNesterov".lower():
|
||||||
|
print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}")
|
||||||
|
optimizer_class = torch.optim.SGD
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}")
|
||||||
|
optimizer_class = torch.optim.AdamW
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay)
|
||||||
|
|
||||||
|
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
||||||
|
|
||||||
|
return optimizer_name, optimizer
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||||
# backward compatibility
|
# backward compatibility
|
||||||
|
|||||||
23
train_db.py
23
train_db.py
@@ -115,32 +115,12 @@ def train(args):
|
|||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
# 8-bit Adamを使う
|
|
||||||
if args.use_8bit_adam:
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print("use 8-bit Adam optimizer")
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
elif args.use_lion_optimizer:
|
|
||||||
try:
|
|
||||||
import lion_pytorch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
|
||||||
print("use Lion optimizer")
|
|
||||||
optimizer_class = lion_pytorch.Lion
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||||
else:
|
else:
|
||||||
trainable_params = unet.parameters()
|
trainable_params = unet.parameters()
|
||||||
|
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
@@ -352,6 +332,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_dataset_arguments(parser, True, False, True)
|
train_util.add_dataset_arguments(parser, True, False, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
|
train_util.add_optimizer_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument("--no_token_padding", action="store_true",
|
parser.add_argument("--no_token_padding", action="store_true",
|
||||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||||
|
|||||||
@@ -208,30 +208,8 @@ def train(args):
|
|||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
# 8-bit Adamを使う
|
|
||||||
if args.use_8bit_adam:
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print("use 8-bit Adam optimizer")
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
elif args.use_lion_optimizer:
|
|
||||||
try:
|
|
||||||
import lion_pytorch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
|
||||||
print("use Lion optimizer")
|
|
||||||
optimizer_class = lion_pytorch.Lion
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
|
||||||
|
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||||
|
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
|
||||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
@@ -555,6 +533,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
train_util.add_optimizer_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||||
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
|
|||||||
@@ -198,29 +198,8 @@ def train(args):
|
|||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
# 8-bit Adamを使う
|
|
||||||
if args.use_8bit_adam:
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print("use 8-bit Adam optimizer")
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
elif args.use_lion_optimizer:
|
|
||||||
try:
|
|
||||||
import lion_pytorch
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
|
|
||||||
print("use Lion optimizer")
|
|
||||||
optimizer_class = lion_pytorch.Lion
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||||
|
optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
|
||||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
@@ -491,6 +470,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True, False)
|
train_util.add_dataset_arguments(parser, True, True, False)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
train_util.add_optimizer_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||||
|
|||||||
Reference in New Issue
Block a user