mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
small fix
This commit is contained in:
@@ -2851,6 +2851,8 @@ def save_sd_model_on_epoch_end(
|
|||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(out_dir, args, "/" + model_name)
|
||||||
|
|
||||||
def remove_du(old_epoch_no):
|
def remove_du(old_epoch_no):
|
||||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||||
@@ -2906,6 +2908,8 @@ def save_sd_model_on_train_end(
|
|||||||
model_util.save_stable_diffusion_checkpoint(
|
model_util.save_stable_diffusion_checkpoint(
|
||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, model_name)
|
out_dir = os.path.join(args.output_dir, model_name)
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
@@ -2914,6 +2918,8 @@ def save_sd_model_on_train_end(
|
|||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(out_dir, args, "/" + model_name)
|
||||||
|
|
||||||
|
|
||||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||||
|
|||||||
@@ -493,6 +493,8 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -537,6 +537,8 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
if args.huggingface_repo_id is not None:
|
||||||
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user