diff --git a/library/train_util.py b/library/train_util.py index 581ad77f..1a28c39a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1407,6 +1407,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + # parser.add_argument("--use_8bit_adam", action="store_true", + # help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + # parser.add_argument("--use_lion_optimizer", action="store_true", + # help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") + # parser.add_argument("--use_dadaptation_optimizer", action="store_true", + # help="use dadaptation optimizer (requires dadaptation) / dadaptaionオプティマイザを使う( dadaptation のインストールが必要)") parser.add_argument("--mem_eff_attn", action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") parser.add_argument("--xformers", action="store_true", @@ -1514,7 +1520,7 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): # region utils -# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit" +# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, dadaption" def get_optimizer(args, trainable_params): # Prepare optimizer/学習に必要なクラスを準備する @@ -1557,6 +1563,17 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=True) + elif optimizer_type == "dadaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + print(f"use dadaptation optimizer") + optimizer_class = dadaptation.DAdaptAdam + if args.learning_rate <= 0.1: + print('learning rate is too low. If using dadaptaion, set learning rate around 1.0.') + print('recommend option: lr=1.0') + optimizer = optimizer_class(trainable_params, lr=lr) else: print(f"use AdamW optimizer | betas: {betas}, Weight Decay: {weight_decay}") optimizer_class = torch.optim.AdamW diff --git a/train_network.py b/train_network.py index b030391c..df987325 100644 --- a/train_network.py +++ b/train_network.py @@ -37,6 +37,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0] logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder + if args.use_dadaptation_optimizer: # tracking d*lr value of unet. + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr'] + return logs