From 80ef59c115a5282f26dafe5ff628e141a4df5950 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 18 Feb 2024 09:12:37 +0900
Subject: [PATCH] support text encoder training in stable cascade

---
 library/stable_cascade_utils.py | 34 +++++++++++------
 stable_cascade_train_stage_c.py | 67 +++++++++++++++++++++++++++------
 2 files changed, 78 insertions(+), 23 deletions(-)

diff --git a/library/stable_cascade_utils.py b/library/stable_cascade_utils.py
index 2406913d..cf971097 100644
--- a/library/stable_cascade_utils.py
+++ b/library/stable_cascade_utils.py
@@ -350,6 +350,24 @@ def get_sai_model_spec(args):
     return metadata
 
 
+def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
+    state_dict = stage_c.state_dict()
+    if save_dtype is not None:
+        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
+
+    save_file(state_dict, ckpt_file, metadata=sai_metadata)
+
+    # save text model
+    if text_model is not None:
+        text_model_sd = text_model.state_dict()
+
+        if save_dtype is not None:
+            text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
+
+        text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
+        save_file(text_model_sd, text_model_ckpt_file)
+
+
 def save_stage_c_model_on_epoch_end_or_stepwise(
     args: argparse.Namespace,
     on_epoch_end: bool,
@@ -359,15 +377,11 @@ def save_stage_c_model_on_epoch_end_or_stepwise(
     num_train_epochs: int,
     global_step: int,
     stage_c,
+    text_model,
 ):
     def stage_c_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = get_sai_model_spec(args)
-
-        state_dict = stage_c.state_dict()
-        if save_dtype is not None:
-            state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
-
-        save_file(state_dict, ckpt_file, metadata=sai_metadata)
+        stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
 
     save_sd_model_on_epoch_end_or_stepwise_common(
         args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
@@ -380,15 +394,11 @@ def save_stage_c_model_on_end(
     epoch: int,
     global_step: int,
     stage_c,
+    text_model,
 ):
     def stage_c_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = get_sai_model_spec(args)
-
-        state_dict = stage_c.state_dict()
-        if save_dtype is not None:
-            state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
-
-        save_file(state_dict, ckpt_file, metadata=sai_metadata)
+        stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
 
     save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
 
diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py
index d5b4d5c0..fc434031 100644
--- a/stable_cascade_train_stage_c.py
+++ b/stable_cascade_train_stage_c.py
@@ -142,7 +142,6 @@ def train(args):
     # 学習を準備する
     if cache_latents:
-        raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません")
         logger.info(
             "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
             + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
         )
@@ -168,9 +167,26 @@ def train(args):
         logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
         # stage_c.enable_gradient_checkpointing()
 
-    text_encoder1.to(weight_dtype)
-    text_encoder1.requires_grad_(False)
-    text_encoder1.eval()
+    train_stage_c = args.learning_rate > 0
+    train_text_encoder1 = False
+
+    if args.train_text_encoder:
+        accelerator.print("enable text encoder training")
+        if args.gradient_checkpointing:
+            text_encoder1.gradient_checkpointing_enable()
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
+        train_text_encoder1 = lr_te1 > 0
+        assert train_text_encoder1, "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
+
+        # caching one text encoder output is not supported
+        if not train_text_encoder1:
+            text_encoder1.to(weight_dtype)
+        text_encoder1.requires_grad_(train_text_encoder1)
+        text_encoder1.train(train_text_encoder1)
+    else:
+        text_encoder1.to(weight_dtype)
+        text_encoder1.requires_grad_(False)
+        text_encoder1.eval()
 
     # TextEncoderの出力をキャッシュする
     if args.cache_text_encoder_outputs:
@@ -199,11 +215,18 @@ def train(args):
     effnet.to(accelerator.device, dtype=effnet_dtype)
 
     stage_c.requires_grad_(True)
+    if not train_stage_c:
+        stage_c.to(accelerator.device, dtype=weight_dtype)  # stage_c will not be prepared by accelerator, so move it to the device here
 
     training_models = []
    params_to_optimize = []
-    training_models.append(stage_c)
-    params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
+    if train_stage_c:
+        training_models.append(stage_c)
+        params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
+
+    if train_text_encoder1:
+        training_models.append(text_encoder1)
+        params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
 
     # calculate number of trainable parameters
     n_params = 0
@@ -211,6 +234,7 @@ def train(args):
         for p in params["params"]:
             n_params += p.numel()
 
+    accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
 
@@ -262,7 +286,10 @@ def train(args):
         text_encoder1.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    stage_c = accelerator.prepare(stage_c)
+    if train_stage_c:
+        stage_c = accelerator.prepare(stage_c)
+    if train_text_encoder1:
+        text_encoder1 = accelerator.prepare(text_encoder1)
 
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
@@ -372,7 +399,7 @@ def train(args):
 
                 if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
                     input_ids1 = batch["input_ids"]
-                    with torch.no_grad():
+                    with torch.set_grad_enabled(args.train_text_encoder):
                         # Get the text embedding for conditioning
                         # TODO support weighted captions
                         input_ids1 = input_ids1.to(accelerator.device)
@@ -436,7 +463,8 @@ def train(args):
                         epoch,
                         num_train_epochs,
                         global_step,
-                        accelerator.accelerator.unwrap_model(stage_c),
+                        accelerator.unwrap_model(stage_c),
+                        accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                     )
 
                 current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
@@ -463,7 +491,15 @@ def train(args):
         if args.save_every_n_epochs is not None:
             if accelerator.is_main_process:
                 sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
-                    args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c)
+                    args,
+                    True,
+                    accelerator,
+                    save_dtype,
+                    epoch,
+                    num_train_epochs,
+                    global_step,
+                    accelerator.unwrap_model(stage_c),
+                    accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                 )
 
         # sdxl_train_util.sample_images(
@@ -481,6 +517,7 @@ def train(args):
     is_main_process = accelerator.is_main_process
     # if is_main_process:
     stage_c = accelerator.unwrap_model(stage_c)
+    text_encoder1 = accelerator.unwrap_model(text_encoder1)
 
     accelerator.end_training()
 
@@ -490,7 +527,9 @@ def train(args):
 
     del accelerator  # この後メモリを使うのでこれは消す
 
     if is_main_process:
-        sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c)
+        sc_utils.save_stage_c_model_on_end(
+            args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
+        )
         logger.info("model saved.")
 
@@ -508,6 +547,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     add_sdxl_training_arguments(parser)  # cache text encoder outputs
 
+    parser.add_argument(
+        "--learning_rate_te1",
+        type=float,
+        default=None,
+        help="learning rate for text encoder / text encoderの学習率",
+    )
     parser.add_argument(
         "--no_half_vae",
         action="store_true",
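
Note (not part of the patch): once applied, text encoder training is switched on with `--train_text_encoder` (assumed here to be registered by the shared argument helpers, since this patch only reads it) together with the new `--learning_rate_te1` option, which falls back to `--learning_rate` when unset. Passing `--learning_rate 0` with a positive `--learning_rate_te1` trains only the text encoder, because stage_c is then excluded from the optimizer and from accelerator.prepare(). The snippet below is a minimal, hypothetical sketch of the file layout produced by the new stage_c_saver_common helper; it assumes a checkout of this repo so that library.stable_cascade_utils is importable, and the tiny nn.Linear modules are stand-ins for the real Stage C and text encoder models, not the actual training flow.

    import torch
    from library.stable_cascade_utils import stage_c_saver_common  # helper added by this patch

    # Stand-in modules: the saver only calls .state_dict() on them,
    # so any torch.nn.Module is enough for a demonstration.
    stage_c = torch.nn.Linear(4, 4)
    text_model = torch.nn.Linear(4, 4)

    # Writes "stage_c.safetensors" with the given metadata and, because
    # text_model is not None, "stage_c_text_model.safetensors" beside it;
    # both state dicts are cast to fp16 before saving.
    stage_c_saver_common("stage_c.safetensors", stage_c, text_model, torch.float16, {"title": "example"})

Saving text_encoder1 to a separate *_text_model.safetensors file keeps the Stage C checkpoint format unchanged, so existing loaders that expect only Stage C weights continue to work.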