mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
resume from huggingface repository
This commit is contained in:
@@ -231,9 +231,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
71
library/huggingface_util.py
Normal file
71
library/huggingface_util.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from typing import *
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from library.utils import fire_in_thread
|
||||||
|
|
||||||
|
|
||||||
|
def exists_repo(
|
||||||
|
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
||||||
|
):
|
||||||
|
api = HfApi(
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@fire_in_thread
|
||||||
|
def upload(
|
||||||
|
src: Union[str, Path, bytes, BinaryIO],
|
||||||
|
args: argparse.Namespace,
|
||||||
|
dest_suffix: str = "",
|
||||||
|
):
|
||||||
|
repo_id = args.huggingface_repo_id
|
||||||
|
repo_type = args.huggingface_repo_type
|
||||||
|
token = args.huggingface_token
|
||||||
|
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||||
|
private = args.huggingface_repo_visibility == "private"
|
||||||
|
api = HfApi(token=token)
|
||||||
|
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||||
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||||
|
|
||||||
|
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||||
|
isinstance(src, Path) and src.is_dir()
|
||||||
|
)
|
||||||
|
if is_folder:
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
folder_path=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
api.upload_file(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
path_or_fileobj=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_dir(
|
||||||
|
repo_id: str,
|
||||||
|
subfolder: str,
|
||||||
|
repo_type: str,
|
||||||
|
revision: str = "main",
|
||||||
|
token: str = None,
|
||||||
|
):
|
||||||
|
api = HfApi(
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||||
|
file_list = [
|
||||||
|
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
|
||||||
|
]
|
||||||
|
return file_list
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
@@ -49,6 +50,7 @@ from diffusers import (
|
|||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
)
|
)
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
import albumentations as albu
|
import albumentations as albu
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -58,7 +60,7 @@ from torch import einsum
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.utils as utils
|
import library.huggingface_util as huggingface_util
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
@@ -1902,6 +1904,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン")
|
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token to upload model / huggingfaceにアップロードするモデルのトークン")
|
||||||
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定")
|
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface model visibility / huggingfaceにアップロードするモデルの公開設定")
|
||||||
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
|
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume_from_huggingface",
|
||||||
|
action="store_true",
|
||||||
|
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_precision",
|
"--save_precision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -2266,6 +2273,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
|||||||
|
|
||||||
# region utils
|
# region utils
|
||||||
|
|
||||||
|
def resume(accelerator, args):
|
||||||
|
if args.resume:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
if args.resume_from_huggingface:
|
||||||
|
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
||||||
|
path_in_repo = "/".join(args.resume.split("/")[2:])
|
||||||
|
revision = None
|
||||||
|
repo_type = None
|
||||||
|
if ":" in path_in_repo:
|
||||||
|
divided = path_in_repo.split(":")
|
||||||
|
if len(divided) == 2:
|
||||||
|
path_in_repo, revision = divided
|
||||||
|
repo_type = "model"
|
||||||
|
else:
|
||||||
|
path_in_repo, revision, repo_type = divided
|
||||||
|
print(
|
||||||
|
f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
|
||||||
|
)
|
||||||
|
|
||||||
|
list_files = huggingface_util.list_dir(
|
||||||
|
repo_id=repo_id,
|
||||||
|
subfolder=path_in_repo,
|
||||||
|
revision=revision,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
repo_type=repo_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download(filename) -> str:
|
||||||
|
def task():
|
||||||
|
return hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
revision=revision,
|
||||||
|
repo_type=repo_type,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, task)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
results = loop.run_until_complete(
|
||||||
|
asyncio.gather(
|
||||||
|
*[download(filename=filename.rfilename) for filename in list_files]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
dirname = os.path.dirname(results[0])
|
||||||
|
accelerator.load_state(dirname)
|
||||||
|
else:
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, trainable_params):
|
def get_optimizer(args, trainable_params):
|
||||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
||||||
@@ -2812,7 +2869,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
accelerator.save_state(state_dir)
|
accelerator.save_state(state_dir)
|
||||||
if args.save_state_to_huggingface:
|
if args.save_state_to_huggingface:
|
||||||
utils.huggingface_upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
if last_n_epochs is not None:
|
if last_n_epochs is not None:
|
||||||
|
|||||||
@@ -1,66 +1,8 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import threading
|
import threading
|
||||||
from typing import *
|
from typing import *
|
||||||
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
|
|
||||||
|
|
||||||
def fire_in_thread(f):
|
def fire_in_thread(f):
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def huggingface_exists_repo(
|
|
||||||
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
|
||||||
):
|
|
||||||
api = HfApi()
|
|
||||||
try:
|
|
||||||
api.repo_info(
|
|
||||||
repo_id=repo_id, token=token, revision=revision, repo_type=repo_type
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@fire_in_thread
|
|
||||||
def huggingface_upload(
|
|
||||||
src: Union[str, Path, bytes, BinaryIO],
|
|
||||||
args: argparse.Namespace,
|
|
||||||
dest_suffix: str = "",
|
|
||||||
):
|
|
||||||
repo_id = args.huggingface_repo_id
|
|
||||||
repo_type = args.huggingface_repo_type
|
|
||||||
token = args.huggingface_token
|
|
||||||
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
|
||||||
private = args.huggingface_repo_visibility == "private"
|
|
||||||
api = HfApi()
|
|
||||||
if not huggingface_exists_repo(
|
|
||||||
repo_id=repo_id, repo_type=repo_type, token=token
|
|
||||||
):
|
|
||||||
api.create_repo(
|
|
||||||
token=token, repo_id=repo_id, repo_type=repo_type, private=private
|
|
||||||
)
|
|
||||||
|
|
||||||
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
|
||||||
isinstance(src, Path) and src.is_dir()
|
|
||||||
)
|
|
||||||
if is_folder:
|
|
||||||
api.upload_folder(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type=repo_type,
|
|
||||||
folder_path=src,
|
|
||||||
path_in_repo=path_in_repo,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
api.upload_file(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type=repo_type,
|
|
||||||
path_or_fileobj=src,
|
|
||||||
path_in_repo=path_in_repo,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
@@ -21,6 +21,6 @@ fairscale==0.4.13
|
|||||||
# for WD14 captioning
|
# for WD14 captioning
|
||||||
# tensorflow<2.11
|
# tensorflow<2.11
|
||||||
tensorflow==2.10.1
|
tensorflow==2.10.1
|
||||||
huggingface-hub==0.12.0
|
huggingface-hub==0.13.3
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
||||||
|
|||||||
@@ -202,9 +202,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.utils as utils
|
import library.huggingface_util as huggingface_util
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
@@ -285,9 +285,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
@@ -628,7 +626,7 @@ def train(args):
|
|||||||
metadata["ss_training_finished_at"] = str(time.time())
|
metadata["ss_training_finished_at"] = str(time.time())
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -668,7 +666,7 @@ def train(args):
|
|||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -304,9 +304,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user