support separate LR for Text Encoder for SD1/2

Kohya S
2023-10-29 21:30:32 +09:00
parent e72020ae01
commit 96d877be90
2 changed files with 39 additions and 9 deletions

File 1 of 2

@@ -10,10 +10,13 @@ import toml
 from tqdm import tqdm
 import torch
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
         from library.ipex import ipex_init
         ipex_init()
 except Exception:
     pass
@@ -193,14 +196,20 @@ def train(args):
     for m in training_models:
         m.requires_grad_(True)

-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]

     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

     # prepare the dataloader
     # a DataLoader with 0 worker processes runs in the main process
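The param-group form above relies on standard PyTorch optimizer behavior: torch.optim optimizers accept either a flat iterable of parameters or a list of dicts, each carrying its own "lr". A minimal sketch, using plain torch.optim.AdamW as a stand-in for train_util.get_optimizer and two small Linear modules as stand-ins for the U-Net and text encoder:

    import torch

    unet_like = torch.nn.Linear(4, 4)  # stand-in for the U-Net
    te_like = torch.nn.Linear(4, 4)    # stand-in for the text encoder

    optimizer = torch.optim.AdamW(
        [
            {"params": list(unet_like.parameters()), "lr": 1e-5},
            {"params": list(te_like.parameters()), "lr": 5e-6},
        ]
    )
    for group in optimizer.param_groups:
        print(group["lr"])  # 1e-05, then 5e-06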
@@ -340,7 +349,7 @@ def train(args):
                 else:
                     target = noise

-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
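The reduction="none" path keeps one loss value per sample (averaging only over channel and spatial dimensions), so per-timestep weights such as min-SNR can be applied before the final batch mean. A minimal sketch with dummy tensors of assumed latent shape (batch, channels, H, W):

    import torch

    noise_pred = torch.randn(2, 4, 8, 8)  # dummy prediction
    target = torch.randn(2, 4, 8, 8)      # dummy target

    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])           # reduce all but the batch dimension
    print(loss.shape)                     # torch.Size([2]) -- one value per sample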
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",
type=float,
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
return parser return parser
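With the new flag, a hypothetical invocation (other required arguments omitted) would set a lower rate for the text encoder than for the U-Net; when --learning_rate_te is not given, both components fall back to --learning_rate. Note the flag only takes effect together with --train_text_encoder:

    accelerate launch fine_tune.py ... --train_text_encoder --learning_rate 1e-5 --learning_rate_te 5e-6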

File 2 of 2

@@ -11,10 +11,13 @@ import toml
 from tqdm import tqdm
 import torch
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
         from library.ipex import ipex_init
         ipex_init()
 except Exception:
     pass
@@ -164,8 +167,14 @@ def train(args):
     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # without wrapping in a list, adamw8bit crashes
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # without wrapping in a list, adamw8bit crashes
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
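The list() wrapping matters because itertools.chain yields a one-shot iterator: once consumed it is empty, which is a plausible reason for the comment's warning that AdamW8bit crashes without it. A quick illustration:

    import itertools

    params = itertools.chain([1, 2], [3])
    print(list(params))  # [1, 2, 3]
    print(list(params))  # [] -- the chain is exhausted after one pass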
@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
     parser.add_argument(
         "--no_token_padding",
         action="store_true",