mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support separate learning rates for TE1/2
This commit is contained in:
@@ -10,10 +10,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -272,10 +275,11 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
training_models.append(unet)
|
||||
train_unet = args.learning_rate > 0
|
||||
train_text_encoder1 = False
|
||||
train_text_encoder2 = False
|
||||
|
||||
if args.train_text_encoder:
|
||||
# TODO each option for two text encoders?
|
||||
@@ -283,9 +287,20 @@ def train(args):
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder1.gradient_checkpointing_enable()
|
||||
text_encoder2.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder1)
|
||||
training_models.append(text_encoder2)
|
||||
# set require_grad=True later
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
train_text_encoder2 = lr_te2 > 0
|
||||
|
||||
# caching one text encoder output is not supported
|
||||
if not train_text_encoder1:
|
||||
text_encoder1.to(weight_dtype)
|
||||
if not train_text_encoder2:
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(train_text_encoder1)
|
||||
text_encoder2.requires_grad_(train_text_encoder2)
|
||||
text_encoder1.train(train_text_encoder1)
|
||||
text_encoder2.train(train_text_encoder2)
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
@@ -313,21 +328,25 @@ def train(args):
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
for m in training_models:
|
||||
m.requires_grad_(True)
|
||||
unet.requires_grad_(train_unet)
|
||||
if not train_unet:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||
|
||||
if block_lrs is None:
|
||||
params_to_optimize = [
|
||||
{"params": list(training_models[0].parameters()), "lr": args.learning_rate},
|
||||
]
|
||||
else:
|
||||
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
if train_unet:
|
||||
training_models.append(unet)
|
||||
if block_lrs is None:
|
||||
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
|
||||
else:
|
||||
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
|
||||
|
||||
for m in training_models[1:]: # Text Encoders if exists
|
||||
params_to_optimize.append({
|
||||
"params": list(m.parameters()),
|
||||
"lr": args.learning_rate_te or args.learning_rate
|
||||
})
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
if train_text_encoder2:
|
||||
training_models.append(text_encoder2)
|
||||
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
@@ -335,6 +354,7 @@ def train(args):
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
@@ -386,16 +406,17 @@ def train(args):
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -461,7 +482,7 @@ def train(args):
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
@@ -547,7 +568,12 @@ def train(args):
|
||||
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
|
||||
if (
|
||||
args.min_snr_gamma
|
||||
or args.scale_v_pred_loss_like_noise_pred
|
||||
or args.v_pred_like_loss
|
||||
or args.debiased_estimation_loss
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -725,7 +751,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate_te1",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate_te2",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
|
||||
)
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
Reference in New Issue
Block a user