support text encoder training in stable cascade

This commit is contained in:
Kohya S
2024-02-18 09:12:37 +09:00
parent 319bbf8057
commit 80ef59c115
2 changed files with 78 additions and 23 deletions

View File

@@ -350,6 +350,24 @@ def get_sai_model_spec(args):
return metadata
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
    """Save Stage-C weights (and optionally the text encoder) as safetensors.

    Args:
        ckpt_file: Destination path for the Stage-C checkpoint; the text
            encoder (if any) is written next to it with a ``_text_model``
            suffix.
        stage_c: Stage-C model whose ``state_dict()`` is saved.
        text_model: Optional text encoder; skipped when ``None``.
        save_dtype: Optional dtype to cast every tensor to before saving
            (``None`` keeps the original dtypes).
        sai_metadata: Metadata dict embedded in the Stage-C safetensors file.
    """
    state_dict = stage_c.state_dict()
    if save_dtype is not None:
        # Fix: original iterated over the bound method `state_dict.items`
        # (missing call parens), raising TypeError at runtime whenever a
        # save dtype was requested. Mirrors the correct text-model path below.
        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
    save_file(state_dict, ckpt_file, metadata=sai_metadata)

    # save text model alongside the Stage-C checkpoint
    if text_model is not None:
        text_model_sd = text_model.state_dict()
        if save_dtype is not None:
            text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
        text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
        save_file(text_model_sd, text_model_ckpt_file)
def save_stage_c_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
@@ -359,15 +377,11 @@ def save_stage_c_model_on_epoch_end_or_stepwise(
num_train_epochs: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
save_file(state_dict, ckpt_file, metadata=sai_metadata)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_epoch_end_or_stepwise_common(
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
@@ -380,15 +394,11 @@ def save_stage_c_model_on_end(
epoch: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
save_file(state_dict, ckpt_file, metadata=sai_metadata)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)

View File

@@ -142,7 +142,6 @@ def train(args):
# 学習を準備する
if cache_latents:
raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません")
logger.info(
"Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
+ " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
@@ -168,9 +167,26 @@ def train(args):
logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
# stage_c.enable_gradient_checkpointing()
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder1.eval()
train_stage_c = args.learning_rate > 0
train_text_encoder1 = False
if args.train_text_encoder:
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
assert train_text_encoder1, "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder1.train(train_text_encoder1)
else:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder1.eval()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
@@ -199,11 +215,18 @@ def train(args):
effnet.to(accelerator.device, dtype=effnet_dtype)
stage_c.requires_grad_(True)
if not train_stage_c:
stage_c.to(accelerator.device, dtype=weight_dtype) # because of stage_c will not be prepared
training_models = []
params_to_optimize = []
training_models.append(stage_c)
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
if train_stage_c:
training_models.append(stage_c)
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
@@ -211,6 +234,7 @@ def train(args):
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
@@ -262,7 +286,10 @@ def train(args):
text_encoder1.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
stage_c = accelerator.prepare(stage_c)
if train_stage_c:
stage_c = accelerator.prepare(stage_c)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
@@ -372,7 +399,7 @@ def train(args):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
with torch.no_grad():
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
input_ids1 = input_ids1.to(accelerator.device)
@@ -436,7 +463,8 @@ def train(args):
epoch,
num_train_epochs,
global_step,
accelerator.accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
@@ -463,7 +491,15 @@ def train(args):
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c)
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
# sdxl_train_util.sample_images(
@@ -481,6 +517,7 @@ def train(args):
is_main_process = accelerator.is_main_process
# if is_main_process:
stage_c = accelerator.unwrap_model(stage_c)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
accelerator.end_training()
@@ -490,7 +527,9 @@ def train(args):
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c)
sc_utils.save_stage_c_model_on_end(
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
)
logger.info("model saved.")
@@ -508,6 +547,12 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
add_sdxl_training_arguments(parser) # cache text encoder outputs
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder / text encoderの学習率",
)
parser.add_argument(
"--no_half_vae",
action="store_true",