mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
change method name, repo is private in default etc
This commit is contained in:
@@ -231,7 +231,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ def upload(
|
|||||||
repo_type = args.huggingface_repo_type
|
repo_type = args.huggingface_repo_type
|
||||||
token = args.huggingface_token
|
token = args.huggingface_token
|
||||||
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
path_in_repo = args.huggingface_path_in_repo + dest_suffix
|
||||||
private = args.huggingface_repo_visibility == "private"
|
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||||
api = HfApi(token=token)
|
api = HfApi(token=token)
|
||||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||||
|
|||||||
@@ -1898,12 +1898,28 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
||||||
parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名")
|
parser.add_argument(
|
||||||
parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類")
|
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
|
||||||
parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--huggingface_path_in_repo",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
||||||
|
)
|
||||||
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
||||||
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定")
|
parser.add_argument(
|
||||||
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
|
"--huggingface_repo_visibility",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume_from_huggingface",
|
"--resume_from_huggingface",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -2278,10 +2294,17 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
|||||||
|
|
||||||
# region utils
|
# region utils
|
||||||
|
|
||||||
def resume(accelerator, args):
|
|
||||||
if args.resume:
|
def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||||
print(f"resume training from state: {args.resume}")
|
if not args.resume:
|
||||||
if args.resume_from_huggingface:
|
return
|
||||||
|
|
||||||
|
if not args.resume_from_huggingface:
|
||||||
|
print(f"resume training from local state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"resume training from huggingface state: {args.resume}")
|
||||||
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
||||||
path_in_repo = "/".join(args.resume.split("/")[2:])
|
path_in_repo = "/".join(args.resume.split("/")[2:])
|
||||||
revision = None
|
revision = None
|
||||||
@@ -2293,9 +2316,7 @@ def resume(accelerator, args):
|
|||||||
repo_type = "model"
|
repo_type = "model"
|
||||||
else:
|
else:
|
||||||
path_in_repo, revision, repo_type = divided
|
path_in_repo, revision, repo_type = divided
|
||||||
print(
|
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
||||||
f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
|
|
||||||
)
|
|
||||||
|
|
||||||
list_files = huggingface_util.list_dir(
|
list_files = huggingface_util.list_dir(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
@@ -2318,15 +2339,11 @@ def resume(accelerator, args):
|
|||||||
return await asyncio.get_event_loop().run_in_executor(None, task)
|
return await asyncio.get_event_loop().run_in_executor(None, task)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
results = loop.run_until_complete(
|
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
||||||
asyncio.gather(
|
if len(results) == 0:
|
||||||
*[download(filename=filename.rfilename) for filename in list_files]
|
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
|
||||||
)
|
|
||||||
)
|
|
||||||
dirname = os.path.dirname(results[0])
|
dirname = os.path.dirname(results[0])
|
||||||
accelerator.load_state(dirname)
|
accelerator.load_state(dirname)
|
||||||
else:
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, trainable_params):
|
def get_optimizer(args, trainable_params):
|
||||||
@@ -2713,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
|
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
@@ -2883,6 +2900,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
accelerator.save_state(state_dir)
|
accelerator.save_state(state_dir)
|
||||||
if args.save_state_to_huggingface:
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading state to huggingface.")
|
||||||
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
@@ -2894,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
shutil.rmtree(state_dir_old)
|
shutil.rmtree(state_dir_old)
|
||||||
|
|
||||||
|
|
||||||
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||||
|
print("saving last state.")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
|
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||||
|
accelerator.save_state(state_dir)
|
||||||
|
if args.save_state_to_huggingface:
|
||||||
|
print("uploading last state to huggingface.")
|
||||||
|
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
|
||||||
|
|
||||||
|
|
||||||
def save_sd_model_on_train_end(
|
def save_sd_model_on_train_end(
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
src_path: str,
|
src_path: str,
|
||||||
@@ -2932,13 +2961,6 @@ def save_sd_model_on_train_end(
|
|||||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
|
||||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
|
||||||
print("saving last state.")
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
|
||||||
|
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
SCHEDULER_LINEAR_END = 0.0120
|
SCHEDULER_LINEAR_END = 0.0120
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -310,7 +310,7 @@ def train(args):
|
|||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
train_util.resume(accelerator, args)
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
@@ -341,9 +341,7 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
print(f"resume training from state: {args.resume}")
|
|
||||||
accelerator.load_state(args.resume)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
|||||||
Reference in New Issue
Block a user