diff --git a/train_network.py b/train_network.py index 8cfe1ab8..85b01def 100644 --- a/train_network.py +++ b/train_network.py @@ -626,7 +626,8 @@ def train(args): metadata["ss_training_finished_at"] = str(time.time()) print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as @@ -666,7 +667,8 @@ def train(args): print(f"save trained model to {ckpt_file}") network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) print("model saved.")