support separate learning rates for TE1/2

This commit is contained in:
Kohya S
2023-10-29 20:38:01 +09:00
parent a9ed4ed8a8
commit 01d929ee2a

View File

@@ -10,10 +10,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -272,10 +275,11 @@ def train(args):
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
train_unet = args.learning_rate > 0
train_text_encoder1 = False
train_text_encoder2 = False
if args.train_text_encoder:
# TODO each option for two text encoders?
@@ -283,9 +287,20 @@ def train(args):
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
# set require_grad=True later
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0
# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
if not train_text_encoder2:
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder2.requires_grad_(train_text_encoder2)
text_encoder1.train(train_text_encoder1)
text_encoder2.train(train_text_encoder2)
else:
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
@@ -313,21 +328,25 @@ def train(args):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
unet.requires_grad_(train_unet)
if not train_unet:
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
if block_lrs is None:
params_to_optimize = [
{"params": list(training_models[0].parameters()), "lr": args.learning_rate},
]
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
training_models = []
params_to_optimize = []
if train_unet:
training_models.append(unet)
if block_lrs is None:
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
else:
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({
"params": list(m.parameters()),
"lr": args.learning_rate_te or args.learning_rate
})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
if train_text_encoder2:
training_models.append(text_encoder2)
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
@@ -335,6 +354,7 @@ def train(args):
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
@@ -386,16 +406,17 @@ def train(args):
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_unet:
unet = accelerator.prepare(unet)
(unet,) = train_util.transform_models_if_DDP([unet])
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
@@ -461,7 +482,7 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
@@ -547,7 +568,12 @@ def train(args):
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
@@ -725,7 +751,19 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
)
parser.add_argument(
"--learning_rate_te2",
type=float,
default=None,
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")