mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Implement huggingface upload for all scripts
This commit is contained in:
@@ -2830,6 +2830,8 @@ def save_sd_model_on_epoch_end(
|
|||||||
model_util.save_stable_diffusion_checkpoint(
|
model_util.save_stable_diffusion_checkpoint(
|
||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_sd(old_epoch_no):
|
def remove_sd(old_epoch_no):
|
||||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import diffusers
|
|||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
@@ -450,6 +451,8 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import diffusers
|
|||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
import library.huggingface_util as huggingface_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
@@ -493,6 +494,8 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
|
|||||||
Reference in New Issue
Block a user