From 86503cb945853f9b7eba781eac980f6e6aaac6b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Wed, 21 Feb 2024 19:38:12 +0800 Subject: [PATCH] add save parser and fix lora scripts model name and hash --- stable_cascade_train_c_network.py | 5 +++-- stable_cascade_train_stage_c.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/stable_cascade_train_c_network.py b/stable_cascade_train_c_network.py index 67ff55fa..f7efc60c 100644 --- a/stable_cascade_train_c_network.py +++ b/stable_cascade_train_c_network.py @@ -730,8 +730,8 @@ class NetworkTrainer: metadata["ss_network_args"] = json.dumps(net_kwargs) # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path + if args.stage_c_checkpoint_path is not None: + sd_model_name = args.stage_c_checkpoint_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) @@ -992,6 +992,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_tokenizer_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py index 325b5d80..de3bfed8 100644 --- a/stable_cascade_train_stage_c.py +++ b/stable_cascade_train_stage_c.py @@ -531,6 +531,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_tokenizer_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_sdxl_training_arguments(parser) # cache text encoder outputs