From 96d877be9076097419a6d9c0501b2370102a8aed Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 29 Oct 2023 21:30:32 +0900
Subject: [PATCH] support separate LR for Text Encoder for SD1/2

---
 fine_tune.py | 27 +++++++++++++++++++++------
 train_db.py  | 21 ++++++++++++++++++---
 2 files changed, 39 insertions(+), 9 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 893066f7..a86a483a 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -10,10 +10,13 @@ import toml
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
 
@@ -193,14 +196,20 @@ def train(args):
 
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]
 
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
 
     # dataloaderを準備する
     # DataLoaderのプロセス数:0はメインプロセスになる
@@ -340,7 +349,7 @@ def train(args):
                 else:
                     target = noise
 
-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
 
     return parser
 
diff --git a/train_db.py b/train_db.py
index 59a124a2..fd8e466e 100644
--- a/train_db.py
+++ b/train_db.py
@@ -11,10 +11,13 @@ import toml
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
 
@@ -164,11 +167,17 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # wightout list, adamw8bit is crashed
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # wightout list, adamw8bit is crashed
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
-
+
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
     # dataloaderを準備する
@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
 
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
     parser.add_argument(
         "--no_token_padding",
         action="store_true",
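
Note (not part of the commit above): a minimal, self-contained sketch of how the two shapes of trainable_params built in this patch behave once they reach a torch optimizer. The stand-in modules and learning-rate values are invented for illustration; in sd-scripts the list is actually routed through train_util.get_optimizer rather than passed to the optimizer constructor directly.

    # Sketch only: stand-in modules and LR values, not sd-scripts code.
    import torch
    import torch.nn as nn

    unet = nn.Linear(4, 4)          # stand-in for the U-Net
    text_encoder = nn.Linear(4, 4)  # stand-in for the Text Encoder

    learning_rate = 1e-5            # what --learning_rate supplies
    learning_rate_te = 5e-6         # what --learning_rate_te supplies; None means "same as unet"

    if learning_rate_te is None:
        # flat list of parameters: a single param group using the optimizer's default lr
        trainable_params = list(unet.parameters()) + list(text_encoder.parameters())
    else:
        # list of param-group dicts: each group carries its own lr
        trainable_params = [
            {"params": list(unet.parameters()), "lr": learning_rate},
            {"params": list(text_encoder.parameters()), "lr": learning_rate_te},
        ]

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
    print([group["lr"] for group in optimizer.param_groups])
    # [1e-05, 5e-06] with separate LRs; [1e-05] when learning_rate_te is None

Leaving --learning_rate_te unset keeps the old single-LR code path (a plain parameter list), so existing configs behave as before, and the list() workaround noted in the "wightout list, adamw8bit is crashed" comment is preserved for 8-bit AdamW.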