From d42431d73a2eafb60446295ef52c6628133d2ad5 Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Tue, 28 Mar 2023 00:49:09 +0900
Subject: [PATCH] Added feature to upload to huggingface

---
 library/train_util.py |  13 +++-
 library/utils.py      | 106 +++++++++++++++++++++++++++++++++++++
 train_network.py      |   3 +
 3 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 library/utils.py

diff --git a/library/train_util.py b/library/train_util.py
index 59dbc44c..179f23e4 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -58,6 +58,7 @@ from torch import einsum
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
+import library.utils as utils
 
 # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1441,7 +1442,6 @@ def glob_images_pathlib(dir_path, recursive):
 
 # endregion
 
-
 # region モジュール入れ替え部
 """
 高速化のためのモジュール入れ替え
@@ -1896,6 +1896,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
     parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
     parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
+    parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload model / huggingfaceにアップロードするモデルのリポジトリ名")
+    parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload model / huggingfaceにアップロードするモデルのリポジトリの種類")
+    parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload model / huggingfaceにアップロードするモデルのパス")
+    parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン")
+    parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定")
+    parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
     parser.add_argument(
         "--save_precision",
         type=str,
@@ -2803,7 +2809,10 @@ def save_sd_model_on_epoch_end(
 
 def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
     print("saving state.")
-    accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
+    state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
+    accelerator.save_state(state_dir)
+    if args.save_state_to_huggingface:
+        utils.huggingface_upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
 
     last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
     if last_n_epochs is not None:
diff --git a/library/utils.py b/library/utils.py
new file mode 100644
index 00000000..68c51fc3
--- /dev/null
+++ b/library/utils.py
@@ -0,0 +1,106 @@
+import argparse
+import functools
+import os
+from pathlib import Path
+import threading
+from typing import BinaryIO, Callable, Optional, Union
+
+from huggingface_hub import HfApi
+from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
+
+
+def fire_in_thread(f: Callable) -> Callable[..., threading.Thread]:
+    """Wrap ``f`` so that every call runs on a newly started thread.
+
+    The wrapper returns the started ``threading.Thread`` rather than the
+    result of ``f``; callers that need to wait for completion can ``join()`` it.
+    """
+
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        thread = threading.Thread(target=f, args=args, kwargs=kwargs)
+        thread.start()
+        return thread
+
+    return wrapped
+
+
+def huggingface_exists_repo(
+    repo_id: str, repo_type: str, revision: str = "main", hf_token: Optional[str] = None
+) -> bool:
+    """Return whether ``revision`` of ``repo_id`` can be looked up on the Hub.
+
+    Only "repository not found" and "revision not found" responses count as
+    "does not exist"; any other failure (e.g. a network error) propagates.
+    """
+    api = HfApi()
+    try:
+        api.repo_info(
+            repo_id=repo_id, token=hf_token, revision=revision, repo_type=repo_type
+        )
+    except (RepositoryNotFoundError, RevisionNotFoundError):
+        return False
+    return True
+
+
+@fire_in_thread
+def huggingface_upload(
+    src: Union[str, Path, bytes, BinaryIO],
+    args: argparse.Namespace,
+    dest_suffix: str = "",
+):
+    """Upload ``src`` to the Hugging Face repo configured in ``args``.
+
+    Runs on a background thread (see ``fire_in_thread``). Does nothing when
+    ``args.huggingface_repo_id`` is not set, so it is safe to call after every
+    save. The repo is created first if it cannot be found.
+
+    Args:
+        src: A directory path (uploaded with ``upload_folder``) or a file
+            path / bytes / binary file object (uploaded with ``upload_file``).
+        args: Parsed training arguments providing the ``huggingface_*`` options.
+        dest_suffix: Appended to ``args.huggingface_path_in_repo`` to form the
+            destination path inside the repo.
+    """
+    repo_id = args.huggingface_repo_id
+    if repo_id is None:
+        # Uploading is opt-in, but the training scripts call this unconditionally.
+        return
+
+    repo_type = args.huggingface_repo_type
+    hf_token = args.huggingface_token
+    # --huggingface_path_in_repo defaults to None, which cannot be concatenated.
+    path_in_repo = (args.huggingface_path_in_repo or "") + dest_suffix
+    private = args.huggingface_repo_visibility == "private"
+
+    api = HfApi()
+    if not huggingface_exists_repo(
+        repo_id=repo_id, repo_type=repo_type, hf_token=hf_token
+    ):
+        # exist_ok: concurrent upload threads may race to create the same repo.
+        api.create_repo(
+            token=hf_token,
+            repo_id=repo_id,
+            repo_type=repo_type,
+            private=private,
+            exist_ok=True,
+        )
+
+    # Pass the token explicitly, as repo_info/create_repo already do, so the
+    # upload does not depend on a separately stored login.
+    if isinstance(src, (str, Path)) and os.path.isdir(src):
+        api.upload_folder(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            folder_path=src,
+            path_in_repo=path_in_repo,
+            token=hf_token,
+        )
+    else:
+        api.upload_file(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            path_or_fileobj=src,
+            path_in_repo=path_in_repo,
+            token=hf_token,
+        )
diff --git a/train_network.py b/train_network.py
index 2b824018..b641e65c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -24,6 +24,7 @@ from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
 )
+import library.utils as utils
 import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import apply_snr_weight
 
@@ -626,6 +627,7 @@ def train(args):
                 metadata["ss_training_finished_at"] = str(time.time())
                 print(f"saving checkpoint: {ckpt_file}")
                 unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
+                utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
 
             def remove_old_func(old_epoch_no):
                 old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -665,6 +667,7 @@ def train(args):
 
         print(f"save trained model to {ckpt_file}")
         network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
+        utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
         print("model saved.")
 
 