diff --git a/library/train_util.py b/library/train_util.py index e4e91ee2..4b9e3ec4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2830,6 +2830,8 @@ def save_sd_model_on_epoch_end( model_util.save_stable_diffusion_checkpoint( args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_sd(old_epoch_no): _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index c5bacf3b..c4b04554 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -13,6 +13,7 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util +import library.huggingface_util as huggingface_util import library.config_util as config_util from library.config_util import ( ConfigSanitizer, @@ -450,6 +451,8 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 74e9bc2e..58c79142 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -13,6 +13,7 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util +import library.huggingface_util as huggingface_util import library.config_util as config_util from library.config_util import ( ConfigSanitizer, @@ -493,6 +494,8 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as