small fix

This commit is contained in:
ddPn08
2023-04-02 00:10:19 +09:00
parent c4a11e5a5a
commit 8bfa50e283
3 changed files with 10 additions and 0 deletions

View File

@@ -2851,6 +2851,8 @@ def save_sd_model_on_epoch_end(
model_util.save_diffusers_checkpoint( model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
) )
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
def remove_du(old_epoch_no): def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2906,6 +2908,8 @@ def save_sd_model_on_train_end(
model_util.save_stable_diffusion_checkpoint( model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
) )
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
else: else:
out_dir = os.path.join(args.output_dir, model_name) out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
@@ -2914,6 +2918,8 @@ def save_sd_model_on_train_end(
model_util.save_diffusers_checkpoint( model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
) )
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
def save_state_on_train_end(args: argparse.Namespace, accelerator): def save_state_on_train_end(args: argparse.Namespace, accelerator):

View File

@@ -493,6 +493,8 @@ def train(args):
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype) save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.") print("model saved.")

View File

@@ -537,6 +537,8 @@ def train(args):
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype) save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.") print("model saved.")