diff --git a/stable_cascade_train_c_network.py b/stable_cascade_train_c_network.py index 67ff55fa..f7efc60c 100644 --- a/stable_cascade_train_c_network.py +++ b/stable_cascade_train_c_network.py @@ -730,8 +730,8 @@ class NetworkTrainer: metadata["ss_network_args"] = json.dumps(net_kwargs) # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path + if args.stage_c_checkpoint_path is not None: + sd_model_name = args.stage_c_checkpoint_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) @@ -992,6 +992,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_tokenizer_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py index 325b5d80..de3bfed8 100644 --- a/stable_cascade_train_stage_c.py +++ b/stable_cascade_train_stage_c.py @@ -531,6 +531,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_tokenizer_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) add_sdxl_training_arguments(parser) # cache text encoder outputs