support separate learning rates for TE1/2
mirror of https://github.com/kohya-ss/sd-scripts.git
@@ -10,10 +10,13 @@ import toml
 
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -272,10 +275,11 @@ def train(args):
     accelerator.wait_for_everyone()
 
     # 学習を準備する:モデルを適切な状態にする
-    training_models = []
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    training_models.append(unet)
+    train_unet = args.learning_rate > 0
+    train_text_encoder1 = False
+    train_text_encoder2 = False
 
     if args.train_text_encoder:
         # TODO each option for two text encoders?
@@ -283,9 +287,20 @@ def train(args):
         if args.gradient_checkpointing:
             text_encoder1.gradient_checkpointing_enable()
             text_encoder2.gradient_checkpointing_enable()
-        training_models.append(text_encoder1)
-        training_models.append(text_encoder2)
-        # set require_grad=True later
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
+        lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
+        train_text_encoder1 = lr_te1 > 0
+        train_text_encoder2 = lr_te2 > 0
+
+        # caching one text encoder output is not supported
+        if not train_text_encoder1:
+            text_encoder1.to(weight_dtype)
+        if not train_text_encoder2:
+            text_encoder2.to(weight_dtype)
+        text_encoder1.requires_grad_(train_text_encoder1)
+        text_encoder2.requires_grad_(train_text_encoder2)
+        text_encoder1.train(train_text_encoder1)
+        text_encoder2.train(train_text_encoder2)
     else:
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)
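Note on the change above: training of each component is now gated by its resolved learning rate. `--learning_rate` gates the U-Net, while each text encoder falls back to `--learning_rate` when its own option is unset and stays frozen when the resolved value is 0. A minimal standalone sketch of that resolution logic (the helper name and default values below are illustrative, not part of this commit):

```python
import argparse

def resolve_training_flags(args):
    # U-Net trains only when the base learning rate is positive.
    train_unet = args.learning_rate > 0
    # Each text encoder LR falls back to the base LR when not given; 0 disables training.
    lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate
    lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate
    train_te1 = args.train_text_encoder and lr_te1 > 0
    train_te2 = args.train_text_encoder and lr_te2 > 0
    return train_unet, (lr_te1, train_te1), (lr_te2, train_te2)

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--learning_rate_te1", type=float, default=None)
parser.add_argument("--learning_rate_te2", type=float, default=None)
parser.add_argument("--train_text_encoder", action="store_true")

# Example: train U-Net and TE1, freeze TE2 explicitly.
args = parser.parse_args(["--train_text_encoder", "--learning_rate_te2", "0"])
print(resolve_training_flags(args))  # (True, (1e-06, True), (0.0, False))
```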
@@ -313,21 +328,25 @@ def train(args):
     vae.eval()
     vae.to(accelerator.device, dtype=vae_dtype)
 
-    for m in training_models:
-        m.requires_grad_(True)
-
-    if block_lrs is None:
-        params_to_optimize = [
-            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
-        ]
-    else:
-        params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
-
-    for m in training_models[1:]:  # Text Encoders if exists
-        params_to_optimize.append({
-            "params": list(m.parameters()),
-            "lr": args.learning_rate_te or args.learning_rate
-        })
+    unet.requires_grad_(train_unet)
+    if not train_unet:
+        unet.to(accelerator.device, dtype=weight_dtype)  # because of unet is not prepared
+
+    training_models = []
+    params_to_optimize = []
+    if train_unet:
+        training_models.append(unet)
+        if block_lrs is None:
+            params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
+        else:
+            params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
+
+    if train_text_encoder1:
+        training_models.append(text_encoder1)
+        params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
+    if train_text_encoder2:
+        training_models.append(text_encoder2)
+        params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
 
     # calculate number of trainable parameters
     n_params = 0
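The `params_to_optimize` entries built above follow PyTorch's per-parameter-group convention, which is how a single optimizer can apply a different learning rate to the U-Net and to each text encoder. A minimal sketch of how such groups are consumed (AdamW and the tiny Linear modules are stand-ins; the script builds its real optimizer from these groups elsewhere):

```python
import torch

# Stand-ins for the U-Net and the two text encoders.
unet = torch.nn.Linear(8, 8)
text_encoder1 = torch.nn.Linear(8, 8)
text_encoder2 = torch.nn.Linear(8, 8)

params_to_optimize = [
    {"params": list(unet.parameters()), "lr": 1e-6},
    {"params": list(text_encoder1.parameters()), "lr": 5e-7},
    {"params": list(text_encoder2.parameters()), "lr": 2e-7},
]

# Each dict becomes one parameter group with its own learning rate.
optimizer = torch.optim.AdamW(params_to_optimize)
for group in optimizer.param_groups:
    print(group["lr"])  # 1e-06, 5e-07, 2e-07
```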
@@ -335,6 +354,7 @@ def train(args):
         for p in params["params"]:
             n_params += p.numel()
 
+    accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
 
@@ -386,16 +406,17 @@ def train(args):
     text_encoder2.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    if args.train_text_encoder:
-        unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
-        )
-
-        # transform DDP after prepare
-        text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
-    else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if train_unet:
+        unet = accelerator.prepare(unet)
         (unet,) = train_util.transform_models_if_DDP([unet])
+    if train_text_encoder1:
+        text_encoder1 = accelerator.prepare(text_encoder1)
+        (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
+    if train_text_encoder2:
+        text_encoder2 = accelerator.prepare(text_encoder2)
+        (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
+
+    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
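Since frozen components are no longer passed to `accelerator.prepare`, only the models that actually receive gradients get wrapped for distributed training; frozen text encoders remain plain modules cast to `weight_dtype`. A rough sketch of the pattern (placeholder modules; assumes the `accelerate` package, and omits the repo's `transform_models_if_DDP` helper):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(4, 4)
text_encoder1 = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)

train_unet, train_text_encoder1 = True, False

# Only models that are trained are passed to prepare(); frozen ones stay untouched.
if train_unet:
    unet = accelerator.prepare(unet)
if train_text_encoder1:
    text_encoder1 = accelerator.prepare(text_encoder1)

# Optimizer (and in the script, dataloader and scheduler) is still prepared unconditionally.
optimizer = accelerator.prepare(optimizer)
```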
@@ -461,7 +482,7 @@ def train(args):
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
-            with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
+            with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                 else:
@@ -547,7 +568,12 @@ def train(args):
 
                     target = noise
 
-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
+                if (
+                    args.min_snr_gamma
+                    or args.scale_v_pred_loss_like_noise_pred
+                    or args.v_pred_like_loss
+                    or args.debiased_estimation_loss
+                ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
@@ -725,7 +751,19 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
-    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
+
+    parser.add_argument(
+        "--learning_rate_te1",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
+    )
+    parser.add_argument(
+        "--learning_rate_te2",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
+    )
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
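Usage note: the single `--learning_rate_te` option is replaced by `--learning_rate_te1` (ViT-L) and `--learning_rate_te2` (BiG-G). When omitted they fall back to `--learning_rate`; passing 0 freezes that encoder, and `--learning_rate 0` now freezes the U-Net itself. For example (illustrative values, assuming the SDXL fine-tuning script sdxl_train.py), `accelerate launch sdxl_train.py --train_text_encoder --learning_rate 1e-6 --learning_rate_te1 3e-7 --learning_rate_te2 0 ...` trains the U-Net and text encoder 1 while keeping text encoder 2 frozen.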