change async uploading to optional

This commit is contained in:
ddPn08
2023-04-02 17:45:26 +09:00
parent 8bfa50e283
commit 16ba1cec69
5 changed files with 38 additions and 26 deletions

View File

@@ -1909,6 +1909,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--async_upload",
action="store_true",
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
)
parser.add_argument(
"--save_precision",
type=str,
@@ -2831,7 +2836,7 @@ def save_sd_model_on_epoch_end(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2852,7 +2857,7 @@ def save_sd_model_on_epoch_end(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
huggingface_util.upload(args, out_dir, "/" + model_name)
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2873,7 +2878,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:
@@ -2909,7 +2914,7 @@ def save_sd_model_on_train_end(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
else:
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
@@ -2919,7 +2924,7 @@ def save_sd_model_on_train_end(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def save_state_on_train_end(args: argparse.Namespace, accelerator):