Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
support separate LR for Text Encoder for SD1/2
fine_tune.py (25 changed lines)
@@ -10,10 +10,13 @@ import toml
 
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -193,14 +196,20 @@ def train(args):
 
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]
 
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
 
     # dataloaderを準備する
     # DataLoaderのプロセス数:0はメインプロセスになる
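
For reference, passing a list of {"params": ..., "lr": ...} dicts this way relies on standard PyTorch optimizer parameter groups, where a per-group "lr" overrides the optimizer-wide default. A minimal sketch of that behavior, using toy Linear modules in place of the actual U-Net and Text Encoder:

# Sketch only: toy modules stand in for the real U-Net / Text Encoder.
import torch

unet = torch.nn.Linear(4, 4)          # placeholder for the U-Net
text_encoder = torch.nn.Linear(4, 4)  # placeholder for the Text Encoder

param_groups = [
    {"params": list(unet.parameters()), "lr": 1e-5},
    {"params": list(text_encoder.parameters()), "lr": 5e-6},
]
optimizer = torch.optim.AdamW(param_groups)

# Each group keeps its own learning rate; a group without "lr"
# would fall back to the optimizer-wide default.
print([group["lr"] for group in optimizer.param_groups])  # [1e-05, 5e-06]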
@@ -340,7 +349,7 @@ def train(args):
                 else:
                     target = noise
 
-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
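
Side note on the reduction="none" pattern in this hunk: it keeps the loss at the full latent shape so per-sample weights (min-SNR, v-pred scaling, debiased estimation) can be applied before the final mean, and .mean([1, 2, 3]) collapses every dimension except batch. A quick shape check, assuming (batch, channels, height, width) latents:

import torch

noise_pred = torch.randn(8, 4, 64, 64)  # (batch, channels, height, width)
target = torch.randn(8, 4, 64, 64)

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])  # average over all dims except batch
print(loss.shape)  # torch.Size([8]) -- one loss value per sample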
@@ -476,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
 
     return parser
 
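
The default=None doubles as a sentinel: when --learning_rate_te is omitted, the param-group branch above is skipped and the Text Encoder simply shares the U-Net learning rate. A small argparse sketch of that fallback (flag names are from the diff; the 2e-6 default is an illustrative value, not the script's actual default):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=2e-6)      # illustrative default
parser.add_argument("--learning_rate_te", type=float, default=None)   # None = "same as unet"

args = parser.parse_args(["--learning_rate_te", "5e-7"])
# fall back to the unet learning rate when the flag is not given
lr_te = args.learning_rate_te if args.learning_rate_te is not None else args.learning_rate
print(lr_te)  # 5e-07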
train_db.py (15 changed lines)
@@ -11,10 +11,13 @@ import toml
 
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -164,8 +167,14 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # wightout list, adamw8bit is crashed
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # wightout list, adamw8bit is crashed
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
 
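
The pre-existing "wightout list, adamw8bit is crashed" comment (sic) points at a real pitfall: itertools.chain returns a one-shot iterator, so an optimizer that walks the parameters more than once sees an empty sequence on the second pass unless the chain is materialized with list(). A quick demonstration:

import itertools
import torch

a = torch.nn.Linear(2, 2)
b = torch.nn.Linear(2, 2)

params = itertools.chain(a.parameters(), b.parameters())
print(len(list(params)))  # 4
print(len(list(params)))  # 0 -- the chain iterator is already exhausted

# Materializing once keeps the parameter list reusable:
params = list(itertools.chain(a.parameters(), b.parameters()))
print(len(params))  # 4, and it stays 4 on repeated iteration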
@@ -461,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
 
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
     parser.add_argument(
         "--no_token_padding",
         action="store_true",