mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Added feature to upload to huggingface
This commit is contained in:
@@ -58,6 +58,7 @@ from torch import einsum
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
import library.utils as utils
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
@@ -1441,7 +1442,6 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
# region モジュール入れ替え部
|
# region モジュール入れ替え部
|
||||||
"""
|
"""
|
||||||
高速化のためのモジュール入れ替え
|
高速化のためのモジュール入れ替え
|
||||||
@@ -1896,6 +1896,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||||
|
parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload model / huggingfaceにアップロードするモデルのリポジトリ名")
|
||||||
|
parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload model / huggingfaceにアップロードするモデルのリポジトリの種類")
|
||||||
|
parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload model / huggingfaceにアップロードするモデルのパス")
|
||||||
|
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン")
|
||||||
|
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定")
|
||||||
|
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_precision",
|
"--save_precision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -2803,7 +2809,10 @@ def save_sd_model_on_epoch_end(
|
|||||||
|
|
||||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
accelerator.save_state(state_dir)
|
||||||
|
if args.save_state_to_huggingface:
|
||||||
|
utils.huggingface_upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
if last_n_epochs is not None:
|
if last_n_epochs is not None:
|
||||||
|
|||||||
64
library/utils.py
Normal file
64
library/utils.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import threading
|
||||||
|
from typing import *
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
|
||||||
|
def fire_in_thread(f):
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def huggingface_exists_repo(
|
||||||
|
repo_id: str, repo_type: str, revision: str = "main", hf_token: str = None
|
||||||
|
):
|
||||||
|
api = HfApi()
|
||||||
|
try:
|
||||||
|
api.repo_info(
|
||||||
|
repo_id=repo_id, token=hf_token, revision=revision, repo_type=repo_type
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@fire_in_thread
|
||||||
|
def huggingface_upload(
|
||||||
|
src: Union[str, Path, bytes, BinaryIO],
|
||||||
|
args: argparse.Namespace,
|
||||||
|
dest_suffix: str = "",
|
||||||
|
):
|
||||||
|
repo_id = args.huggingface_repo_id
|
||||||
|
repo_type = args.huggingface_repo_type
|
||||||
|
hf_token = args.huggingface_token
|
||||||
|
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||||
|
private = args.huggingface_repo_visibility == "private"
|
||||||
|
api = HfApi()
|
||||||
|
if not huggingface_exists_repo(
|
||||||
|
repo_id=repo_id, repo_type=repo_type, hf_token=hf_token
|
||||||
|
):
|
||||||
|
api.create_repo(
|
||||||
|
token=hf_token, repo_id=repo_id, repo_type=repo_type, private=private
|
||||||
|
)
|
||||||
|
|
||||||
|
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||||
|
isinstance(src, Path) and src.is_dir()
|
||||||
|
)
|
||||||
|
if is_folder:
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
folder_path=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
api.upload_file(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
path_or_fileobj=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
@@ -24,6 +24,7 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
import library.utils as utils
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
@@ -626,6 +627,7 @@ def train(args):
|
|||||||
metadata["ss_training_finished_at"] = str(time.time())
|
metadata["ss_training_finished_at"] = str(time.time())
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
|
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -665,6 +667,7 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
|
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user