mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
83 lines
3.1 KiB
Python
83 lines
3.1 KiB
Python
from typing import Union, BinaryIO
|
||
from huggingface_hub import HfApi
|
||
from pathlib import Path
|
||
import argparse
|
||
import os
|
||
from library.utils import fire_in_thread
|
||
from library.utils import setup_logging
|
||
|
||
# Initialise logging through the project helper (library.utils.setup_logging)
# before the module-level logger is obtained.
setup_logging()
import logging

# Module-level logger; upload() uses it to report repo-creation and upload failures.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
    """Return whether ``HfApi.repo_info`` succeeds for the given repository and revision.

    Args:
        repo_id: Repository id on the Hugging Face Hub.
        repo_type: Repository type, passed unchanged to ``HfApi.repo_info``.
        revision: Revision to query; defaults to ``"main"``.
        token: Access token, passed unchanged to ``HfApi``.

    Returns:
        ``True`` if the ``repo_info`` call returns normally, ``False`` if it raises
        any ``Exception`` (not found, no access, network failure, ...).
    """
    api = HfApi(token=token)
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
        return True
    except Exception:
        # Best-effort probe: any ordinary error means "treat as missing".
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being reported as "repo missing".
        return False
|
||
|
||
|
||
def upload(
    args: argparse.Namespace,
    src: Union[str, Path, bytes, BinaryIO],
    dest_suffix: str = "",
    force_sync_upload: bool = False,
):
    """Upload ``src`` to the Hugging Face Hub repository configured in ``args``.

    If ``exists_repo`` reports the repository as missing, it is created first
    (private unless ``args.huggingface_repo_visibility == "public"``).
    A ``str``/``Path`` pointing at an existing directory is sent with
    ``HfApi.upload_folder``; anything else is passed to ``HfApi.upload_file``.
    Errors from repo creation and from the upload are logged, not raised.

    Args:
        args: Parsed CLI arguments. This function reads ``huggingface_repo_id``,
            ``huggingface_repo_type``, ``huggingface_token``,
            ``huggingface_path_in_repo``, ``huggingface_repo_visibility`` and
            ``async_upload``.
        src: Local file or directory path, raw bytes, or a binary file object.
        dest_suffix: Appended to ``args.huggingface_path_in_repo``. It is ignored
            (and the destination path is ``None``) when that attribute is ``None``.
        force_sync_upload: When True, upload on the calling thread even if
            ``args.async_upload`` is set.
    """
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    base_path_in_repo = args.huggingface_path_in_repo
    path_in_repo = base_path_in_repo + dest_suffix if base_path_in_repo is not None else None
    # Anything other than exactly "public" (including None) yields a private repo.
    private = args.huggingface_repo_visibility != "public"

    api = HfApi(token=token)

    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        try:
            api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
        except Exception as e:  # RepositoryNotFoundError has been observed; catch broadly in case other errors occur
            logger.error("===========================================")
            logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
            logger.error("===========================================")

    # os.path.isdir accepts both str and Path, so one check covers both path types;
    # bytes / file objects are never treated as folders.
    is_folder = isinstance(src, (str, Path)) and os.path.isdir(src)

    def uploader():
        try:
            # We do our own threading, so run_as_future is explicitly False
            # (the Hub API's own future mode may be buggy here).
            if is_folder:
                api.upload_folder(
                    repo_id=repo_id, repo_type=repo_type, folder_path=src, path_in_repo=path_in_repo, run_as_future=False
                )
            else:
                api.upload_file(
                    repo_id=repo_id, repo_type=repo_type, path_or_fileobj=src, path_in_repo=path_in_repo, run_as_future=False
                )
        except Exception as e:  # RuntimeError has been observed; catch broadly in case other errors occur
            logger.error("===========================================")
            logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
            logger.error("===========================================")

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
    else:
        uploader()
|
||
|
||
|
||
def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: str = None,
):
    """Return the repository file entries whose path starts with ``subfolder``.

    Args:
        repo_id: Repository id on the Hugging Face Hub.
        subfolder: Prefix to filter ``rfilename`` by. This is a plain string-prefix
            match, so e.g. ``"state"`` also matches ``"state-old/x"``; include a
            trailing ``"/"`` to restrict to a directory.
        repo_type: Repository type, passed unchanged to ``HfApi.repo_info``.
        revision: Revision to query; defaults to ``"main"``.
        token: Access token, passed unchanged to ``HfApi``.

    Returns:
        A list of the sibling entries from ``repo_info(...).siblings`` that match,
        or an empty list when the Hub response carries no file listing.
    """
    api = HfApi(token=token)
    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    # huggingface_hub types `siblings` as optional; iterating None would raise TypeError.
    siblings = repo_info.siblings or []
    return [sibling for sibling in siblings if sibling.rfilename.startswith(subfolder)]
|