Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-10 06:54:17 +00:00
support text encoder training in stable cascade
@@ -350,6 +350,24 @@ def get_sai_model_spec(args):
     return metadata


+def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
+    state_dict = stage_c.state_dict()
+    if save_dtype is not None:
+        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
+
+    save_file(state_dict, ckpt_file, metadata=sai_metadata)
+
+    # save text model
+    if text_model is not None:
+        text_model_sd = text_model.state_dict()
+
+        if save_dtype is not None:
+            text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
+
+        text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
+        save_file(text_model_sd, text_model_ckpt_file)
+
+
 def save_stage_c_model_on_epoch_end_or_stepwise(
     args: argparse.Namespace,
     on_epoch_end: bool,
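The new stage_c_saver_common helper writes the Stage C weights to ckpt_file and, when a text model is passed, a companion file next to it with a _text_model.safetensors suffix. A minimal loading sketch for that pair (the checkpoint path is a hypothetical placeholder; only the suffix convention comes from the helper above):

import os

from safetensors.torch import load_file

ckpt_file = "stage_c_checkpoint.safetensors"  # hypothetical path produced by the saver
stage_c_sd = load_file(ckpt_file)  # dict of tensors for Stage C

# trained text encoder weights, if any, sit next to the Stage C checkpoint
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
if os.path.exists(text_model_ckpt_file):
    text_model_sd = load_file(text_model_ckpt_file)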
@@ -359,15 +377,11 @@ def save_stage_c_model_on_epoch_end_or_stepwise(
     num_train_epochs: int,
     global_step: int,
     stage_c,
+    text_model,
 ):
     def stage_c_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = get_sai_model_spec(args)
-
-        state_dict = stage_c.state_dict()
-        if save_dtype is not None:
-            state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
-
-        save_file(state_dict, ckpt_file, metadata=sai_metadata)
+        stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)

     save_sd_model_on_epoch_end_or_stepwise_common(
         args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
@@ -380,15 +394,11 @@ def save_stage_c_model_on_end(
     epoch: int,
     global_step: int,
     stage_c,
+    text_model,
 ):
     def stage_c_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = get_sai_model_spec(args)
-
-        state_dict = stage_c.state_dict()
-        if save_dtype is not None:
-            state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
-
-        save_file(state_dict, ckpt_file, metadata=sai_metadata)
+        stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)

     save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)

@@ -142,7 +142,6 @@ def train(args):

     # prepare for training
     if cache_latents:
-        raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません")
         logger.info(
             "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
             + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
@@ -168,9 +167,26 @@ def train(args):
         logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
         # stage_c.enable_gradient_checkpointing()

-    text_encoder1.to(weight_dtype)
-    text_encoder1.requires_grad_(False)
-    text_encoder1.eval()
+    train_stage_c = args.learning_rate > 0
+    train_text_encoder1 = False
+
+    if args.train_text_encoder:
+        accelerator.print("enable text encoder training")
+        if args.gradient_checkpointing:
+            text_encoder1.gradient_checkpointing_enable()
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
+        train_text_encoder1 = lr_te1 > 0
+        assert train_text_encoder1, "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
+
+        # caching one text encoder output is not supported
+        if not train_text_encoder1:
+            text_encoder1.to(weight_dtype)
+        text_encoder1.requires_grad_(train_text_encoder1)
+        text_encoder1.train(train_text_encoder1)
+    else:
+        text_encoder1.to(weight_dtype)
+        text_encoder1.requires_grad_(False)
+        text_encoder1.eval()

     # cache the TextEncoder outputs
     if args.cache_text_encoder_outputs:
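Whether the text encoder actually learns is controlled purely through requires_grad_ and train/eval mode, mirroring how stage_c is gated by train_stage_c. A tiny stand-alone sketch of that freeze/unfreeze pattern (the Linear module is just a stand-in for the real text encoder):

import torch

text_encoder1 = torch.nn.Linear(8, 8)  # stand-in module, not the real text encoder
train_text_encoder1 = False

text_encoder1.requires_grad_(train_text_encoder1)  # freeze parameters when not training
text_encoder1.train(train_text_encoder1)           # eval mode when frozen, train mode otherwise
print(all(not p.requires_grad for p in text_encoder1.parameters()))  # True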
@@ -199,11 +215,18 @@ def train(args):
         effnet.to(accelerator.device, dtype=effnet_dtype)

     stage_c.requires_grad_(True)
+    if not train_stage_c:
+        stage_c.to(accelerator.device, dtype=weight_dtype)  # because stage_c will not be prepared

     training_models = []
     params_to_optimize = []
-    training_models.append(stage_c)
-    params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
+    if train_stage_c:
+        training_models.append(stage_c)
+        params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
+
+    if train_text_encoder1:
+        training_models.append(text_encoder1)
+        params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})

     # calculate number of trainable parameters
     n_params = 0
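params_to_optimize uses PyTorch's standard parameter-group format, which is what lets Stage C and the text encoder share one optimizer while keeping separate learning rates. A minimal sketch of the idea with stand-in modules and learning rates (the values are illustrative, not taken from this commit):

import torch

stage_c = torch.nn.Linear(8, 8)        # stand-in for the Stage C model
text_encoder1 = torch.nn.Linear(8, 8)  # stand-in for the text encoder

params_to_optimize = [
    {"params": list(stage_c.parameters()), "lr": 1e-4},
    {"params": list(text_encoder1.parameters()), "lr": 5e-6},
]
optimizer = torch.optim.AdamW(params_to_optimize)  # each group keeps its own learning rate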
@@ -211,6 +234,7 @@ def train(args):
         for p in params["params"]:
             n_params += p.numel()

+    accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")

@@ -262,7 +286,10 @@ def train(args):
         text_encoder1.to(weight_dtype)

     # apparently accelerator takes care of this for us
-    stage_c = accelerator.prepare(stage_c)
+    if train_stage_c:
+        stage_c = accelerator.prepare(stage_c)
+    if train_text_encoder1:
+        text_encoder1 = accelerator.prepare(text_encoder1)

     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

@@ -372,7 +399,7 @@ def train(args):

                 if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
                     input_ids1 = batch["input_ids"]
-                    with torch.no_grad():
+                    with torch.set_grad_enabled(args.train_text_encoder):
                         # Get the text embedding for conditioning
                         # TODO support weighted captions
                         input_ids1 = input_ids1.to(accelerator.device)
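Swapping torch.no_grad() for torch.set_grad_enabled(args.train_text_encoder) is what lets gradients flow through the text encoder when its training is enabled, while preserving the old no-grad behavior otherwise. A small self-contained illustration of the difference:

import torch

x = torch.randn(2, 2, requires_grad=True)

with torch.set_grad_enabled(False):  # behaves like torch.no_grad()
    y = (x * 2.0).sum()
print(y.requires_grad)  # False

with torch.set_grad_enabled(True):  # gradients are tracked, as needed for text encoder training
    y = (x * 2.0).sum()
print(y.requires_grad)  # True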
@@ -436,7 +463,8 @@ def train(args):
                             epoch,
                             num_train_epochs,
                             global_step,
-                            accelerator.accelerator.unwrap_model(stage_c),
+                            accelerator.unwrap_model(stage_c),
+                            accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                         )

             current_loss = loss.detach().item()  # this is the mean, so batch size should not matter
@@ -463,7 +491,15 @@ def train(args):
         if args.save_every_n_epochs is not None:
             if accelerator.is_main_process:
                 sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
-                    args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c)
+                    args,
+                    True,
+                    accelerator,
+                    save_dtype,
+                    epoch,
+                    num_train_epochs,
+                    global_step,
+                    accelerator.unwrap_model(stage_c),
+                    accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                 )

         # sdxl_train_util.sample_images(
@@ -481,6 +517,7 @@ def train(args):
     is_main_process = accelerator.is_main_process
     # if is_main_process:
     stage_c = accelerator.unwrap_model(stage_c)
+    text_encoder1 = accelerator.unwrap_model(text_encoder1)

     accelerator.end_training()

@@ -490,7 +527,9 @@ def train(args):
     del accelerator  # delete this because memory is needed afterwards

     if is_main_process:
-        sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c)
+        sc_utils.save_stage_c_model_on_end(
+            args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
+        )
         logger.info("model saved.")


@@ -508,6 +547,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     add_sdxl_training_arguments(parser)  # cache text encoder outputs

+    parser.add_argument(
+        "--learning_rate_te1",
+        type=float,
+        default=None,
+        help="learning rate for text encoder / text encoderの学習率",
+    )
     parser.add_argument(
         "--no_half_vae",
         action="store_true",