From 3cc4939dd38d52a077b97e53260631b4da755628 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sat, 1 Apr 2023 23:16:02 +0900 Subject: [PATCH] Implement huggingface upload for all scripts --- library/train_util.py | 2 ++ train_textual_inversion.py | 3 +++ train_textual_inversion_XTI.py | 3 +++ 3 files changed, 8 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index e4e91ee2..4b9e3ec4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2830,6 +2830,8 @@ def save_sd_model_on_epoch_end( model_util.save_stable_diffusion_checkpoint( args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae ) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_sd(old_epoch_no): _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index c5bacf3b..c4b04554 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -13,6 +13,7 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util +import library.huggingface_util as huggingface_util import library.config_util as config_util from library.config_util import ( ConfigSanitizer, @@ -450,6 +451,8 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 74e9bc2e..58c79142 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -13,6 +13,7 @@ import diffusers from diffusers import DDPMScheduler import library.train_util as train_util +import library.huggingface_util as huggingface_util import library.config_util as config_util from library.config_util import ( ConfigSanitizer, @@ -493,6 +494,8 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as