diff --git a/library/huggingface_util.py b/library/huggingface_util.py index 353189c0..4431a208 100644 --- a/library/huggingface_util.py +++ b/library/huggingface_util.py @@ -20,11 +20,11 @@ def exists_repo( return False -@fire_in_thread def upload( - src: Union[str, Path, bytes, BinaryIO], args: argparse.Namespace, + src: Union[str, Path, bytes, BinaryIO], dest_suffix: str = "", + force_sync_upload: bool = False, ): repo_id = args.huggingface_repo_id repo_type = args.huggingface_repo_type @@ -38,20 +38,27 @@ def upload( is_folder = (type(src) == str and os.path.isdir(src)) or ( isinstance(src, Path) and src.is_dir() ) - if is_folder: - api.upload_folder( - repo_id=repo_id, - repo_type=repo_type, - folder_path=src, - path_in_repo=path_in_repo, - ) + + def uploader(): + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + + if args.async_upload and not force_sync_upload: + fire_in_thread(uploader) else: - api.upload_file( - repo_id=repo_id, - repo_type=repo_type, - path_or_fileobj=src, - path_in_repo=path_in_repo, - ) + uploader() def list_dir( diff --git a/library/train_util.py b/library/train_util.py index c6d49974..425159c2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1909,6 +1909,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", ) + parser.add_argument( + "--async_upload", + action="store_true", + help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", + ) parser.add_argument( "--save_precision", type=str, @@ -2831,7 +2836,7 @@ def save_sd_model_on_epoch_end( args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae ) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) def remove_sd(old_epoch_no): _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) @@ -2852,7 +2857,7 @@ def save_sd_model_on_epoch_end( args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors ) if args.huggingface_repo_id is not None: - huggingface_util.upload(out_dir, args, "/" + model_name) + huggingface_util.upload(args, out_dir, "/" + model_name) def remove_du(old_epoch_no): out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) @@ -2873,7 +2878,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) accelerator.save_state(state_dir) if args.save_state_to_huggingface: - huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) + huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs if last_n_epochs is not None: @@ -2909,7 +2914,7 @@ def save_sd_model_on_train_end( args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae ) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) else: out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) @@ -2919,7 +2924,7 @@ def save_sd_model_on_train_end( args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors ) if args.huggingface_repo_id is not None: - huggingface_util.upload(out_dir, args, "/" + model_name) + huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) def save_state_on_train_end(args: argparse.Namespace, accelerator): diff --git a/train_network.py b/train_network.py index 85b01def..dc890b99 100644 --- a/train_network.py +++ b/train_network.py @@ -627,7 +627,7 @@ def train(args): print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as @@ -668,7 +668,7 @@ def train(args): print(f"save trained model to {ckpt_file}") network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) print("model saved.") diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 3fb17f2e..e7d052ee 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -452,7 +452,7 @@ def train(args): print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as @@ -494,7 +494,7 @@ def train(args): print(f"save trained model to {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) print("model saved.") diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index b0bc4c3a..7e393bcd 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -495,7 +495,7 @@ def train(args): print(f"saving checkpoint: {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as @@ -538,7 +538,7 @@ def train(args): print(f"save trained model to {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) if args.huggingface_repo_id is not None: - huggingface_util.upload(ckpt_file, args, "/" + ckpt_name) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) print("model saved.")