mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Implement huggingface upload for all scripts
This commit is contained in:
@@ -2830,6 +2830,8 @@ def save_sd_model_on_epoch_end(
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||
|
||||
def remove_sd(old_epoch_no):
|
||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||
|
||||
@@ -13,6 +13,7 @@ import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
@@ -450,6 +451,8 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||
|
||||
@@ -13,6 +13,7 @@ import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
@@ -493,6 +494,8 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||
|
||||
Reference in New Issue
Block a user