diff --git a/fine_tune.py b/fine_tune.py index 13241bc6..a3588c37 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -149,27 +149,7 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print("use 8-bit Adam optimizer") - optimizer_class = bnb.optim.AdamW8bit - elif args.use_lion_optimizer: - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print("use Lion optimizer") - optimizer_class = lion_pytorch.Lion - else: - optimizer_class = torch.optim.AdamW - - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) + optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる @@ -351,6 +331,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, False, True, True) train_util.add_training_arguments(parser, False) train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') diff --git a/library/train_util.py b/library/train_util.py index 63868f98..581ad77f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1366,6 +1366,26 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") +def add_optimizer_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--optimizer_type", type=str, default="AdamW", + help="Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit") + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--optimizer_momentum", type=float, default=0.9, + help="Momentum value for optimizers") + parser.add_argument("--optimizer_weightdecay", type=float, default=0.01, + help="Weight decay for optimizers") + parser.add_argument("--optimizer_beta1", type=float, default=0.9, + help="beta1 parameter for Adam optimizers") + parser.add_argument("--optimizer_beta2", type=float, default=0.999, + help="beta2 parameter for Adam optimizers") + + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + + def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") @@ -1387,10 +1407,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--use_lion_optimizer", action="store_true", - help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") parser.add_argument("--mem_eff_attn", action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") parser.add_argument("--xformers", action="store_true", @@ -1398,7 +1414,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") @@ -1419,10 +1434,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--logging_dir", type=str, default=None, help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") parser.add_argument("--noise_offset", type=float, default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") parser.add_argument("--lowram", action="store_true", @@ -1503,6 +1514,58 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): # region utils +# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit" + +def get_optimizer(args, trainable_params): + # Prepare optimizer/学習に必要なクラスを準備する + optimizer_type = args.optimizer_type.lower() + + betas = (args.optimizer_beta1, args.optimizer_beta2) + weight_decay = args.optimizer_weightdecay + momentum = args.optimizer_momentum + lr = args.learning_rate + + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}") + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print(f"use Lion optimizer | betas: {betas}, Weight Decay: {weight_decay}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | Momentum: {momentum}, Weight Decay: {weight_decay}") + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True) + + else: + print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, betas=betas, weight_decay=weight_decay) + + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + + return optimizer_name, optimizer + def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): # backward compatibility diff --git a/train_db.py b/train_db.py index 1903c4c4..51e588fc 100644 --- a/train_db.py +++ b/train_db.py @@ -115,32 +115,12 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print("use 8-bit Adam optimizer") - optimizer_class = bnb.optim.AdamW8bit - elif args.use_lion_optimizer: - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print("use Lion optimizer") - optimizer_class = lion_pytorch.Lion - else: - optimizer_class = torch.optim.AdamW - if train_text_encoder: trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) else: trainable_params = unet.parameters() - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる @@ -352,6 +332,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, False, True) train_util.add_training_arguments(parser, True) train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") diff --git a/train_network.py b/train_network.py index 1489691d..b030391c 100644 --- a/train_network.py +++ b/train_network.py @@ -208,30 +208,8 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print("use 8-bit Adam optimizer") - optimizer_class = bnb.optim.AdamW8bit - elif args.use_lion_optimizer: - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print("use Lion optimizer") - optimizer_class = lion_pytorch.Lion - else: - optimizer_class = torch.optim.AdamW - - optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる @@ -555,6 +533,7 @@ if __name__ == '__main__': train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ffec0516..1913da7e 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -198,29 +198,8 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print("use 8-bit Adam optimizer") - optimizer_class = bnb.optim.AdamW8bit - elif args.use_lion_optimizer: - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print("use Lion optimizer") - optimizer_class = lion_pytorch.Lion - else: - optimizer_class = torch.optim.AdamW - trainable_params = text_encoder.get_input_embeddings().parameters() - - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + optimizer_name, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる @@ -491,6 +470,7 @@ if __name__ == '__main__': train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, True, True, False) train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")