mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Change method name; make the Hugging Face repo private by default, etc.
This commit is contained in:
@@ -231,7 +231,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def upload(
|
|||||||
repo_type = args.huggingface_repo_type
|
repo_type = args.huggingface_repo_type
|
||||||
token = args.huggingface_token
|
token = args.huggingface_token
|
||||||
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||||
private = args.huggingface_repo_visibility == "private"
|
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||||
api = HfApi(token=token)
|
api = HfApi(token=token)
|
||||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||||
|
|||||||
@@ -490,7 +490,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||||
tokens_len = (
|
tokens_len = (
|
||||||
@@ -1898,12 +1898,28 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||||
parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名")
|
parser.add_argument(
|
||||||
parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類")
|
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
|
||||||
parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_path_in_repo",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
||||||
|
)
|
||||||
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
||||||
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定")
|
parser.add_argument(
|
||||||
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
|
"--huggingface_repo_visibility",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume_from_huggingface",
|
"--resume_from_huggingface",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -2278,55 +2294,56 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
|||||||
|
|
||||||
# region utils
|
# region utils
|
||||||
|
|
||||||
def resume(accelerator, args):
|
|
||||||
if args.resume:
|
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
if args.resume_from_huggingface:
|
|
||||||
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
|
||||||
path_in_repo = "/".join(args.resume.split("/")[2:])
|
|
||||||
revision = None
|
|
||||||
repo_type = None
|
|
||||||
if ":" in path_in_repo:
|
|
||||||
divided = path_in_repo.split(":")
|
|
||||||
if len(divided) == 2:
|
|
||||||
path_in_repo, revision = divided
|
|
||||||
repo_type = "model"
|
|
||||||
else:
|
|
||||||
path_in_repo, revision, repo_type = divided
|
|
||||||
print(
|
|
||||||
f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
|
|
||||||
)
|
|
||||||
|
|
||||||
list_files = huggingface_util.list_dir(
|
def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||||
repo_id=repo_id,
|
if not args.resume:
|
||||||
subfolder=path_in_repo,
|
return
|
||||||
revision=revision,
|
|
||||||
token=args.huggingface_token,
|
|
||||||
repo_type=repo_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def download(filename) -> str:
|
if not args.resume_from_huggingface:
|
||||||
def task():
|
print(f"resume training from local state: {args.resume}")
|
||||||
return hf_hub_download(
|
accelerator.load_state(args.resume)
|
||||||
repo_id=repo_id,
|
return
|
||||||
filename=filename,
|
|
||||||
revision=revision,
|
|
||||||
repo_type=repo_type,
|
|
||||||
token=args.huggingface_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, task)
|
print(f"resume training from huggingface state: {args.resume}")
|
||||||
|
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
||||||
loop = asyncio.get_event_loop()
|
path_in_repo = "/".join(args.resume.split("/")[2:])
|
||||||
results = loop.run_until_complete(
|
revision = None
|
||||||
asyncio.gather(
|
repo_type = None
|
||||||
*[download(filename=filename.rfilename) for filename in list_files]
|
if ":" in path_in_repo:
|
||||||
)
|
divided = path_in_repo.split(":")
|
||||||
)
|
if len(divided) == 2:
|
||||||
dirname = os.path.dirname(results[0])
|
path_in_repo, revision = divided
|
||||||
accelerator.load_state(dirname)
|
repo_type = "model"
|
||||||
else:
|
else:
|
||||||
accelerator.load_state(args.resume)
|
path_in_repo, revision, repo_type = divided
|
||||||
|
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
||||||
|
|
||||||
|
list_files = huggingface_util.list_dir(
|
||||||
|
repo_id=repo_id,
|
||||||
|
subfolder=path_in_repo,
|
||||||
|
revision=revision,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
repo_type=repo_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download(filename) -> str:
|
||||||
|
def task():
|
||||||
|
return hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
revision=revision,
|
||||||
|
repo_type=repo_type,
|
||||||
|
token=args.huggingface_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(None, task)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
||||||
|
if len(results) == 0:
|
||||||
|
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
|
||||||
|
dirname = os.path.dirname(results[0])
|
||||||
|
accelerator.load_state(dirname)
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, trainable_params):
|
def get_optimizer(args, trainable_params):
|
||||||
@@ -2713,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
|
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
@@ -2883,6 +2900,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
accelerator.save_state(state_dir)
|
accelerator.save_state(state_dir)
|
||||||
if args.save_state_to_huggingface:
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading state to huggingface.")
|
||||||
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
@@ -2894,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
shutil.rmtree(state_dir_old)
|
shutil.rmtree(state_dir_old)
|
||||||
|
|
||||||
|
|
||||||
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||||
|
print("saving last state.")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
|
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||||
|
accelerator.save_state(state_dir)
|
||||||
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading last state to huggingface.")
|
||||||
|
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
|
||||||
|
|
||||||
|
|
||||||
def save_sd_model_on_train_end(
|
def save_sd_model_on_train_end(
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
src_path: str,
|
src_path: str,
|
||||||
@@ -2932,13 +2961,6 @@ def save_sd_model_on_train_end(
|
|||||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
|
||||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
||||||
print("saving last state.")
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
|
||||||
|
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
SCHEDULER_LINEAR_END = 0.0120
|
SCHEDULER_LINEAR_END = 0.0120
|
||||||
@@ -3168,7 +3190,7 @@ class collater_class:
|
|||||||
def __init__(self, epoch, step, dataset):
|
def __init__(self, epoch, step, dataset):
|
||||||
self.current_epoch = epoch
|
self.current_epoch = epoch
|
||||||
self.current_step = step
|
self.current_step = step
|
||||||
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
||||||
|
|
||||||
def __call__(self, examples):
|
def __call__(self, examples):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -341,9 +341,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user