change method name, repo is private in default etc

This commit is contained in:
Kohya S
2023-04-05 23:16:49 +09:00
parent 74220bb52c
commit 541539a144
7 changed files with 88 additions and 68 deletions

View File

@@ -231,7 +231,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume(accelerator, args)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -30,7 +30,7 @@ def upload(
repo_type = args.huggingface_repo_type
token = args.huggingface_token
path_in_repo = args.huggingface_path_in_repo + dest_suffix
private = args.huggingface_repo_visibility == "private"
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)

View File

@@ -1898,12 +1898,28 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument("--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名")
parser.add_argument("--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類")
parser.add_argument("--huggingface_path_in_repo", type=str, default=None, help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス")
parser.add_argument(
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ"
)
parser.add_argument(
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
)
parser.add_argument(
"--huggingface_path_in_repo",
type=str,
default=None,
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
)
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
parser.add_argument("--huggingface_repo_visibility", type=str, default=None, help="huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定")
parser.add_argument("--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する")
parser.add_argument(
"--huggingface_repo_visibility",
type=str,
default=None,
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定'public'で公開、'private'またはNoneで非公開",
)
parser.add_argument(
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
)
parser.add_argument(
"--resume_from_huggingface",
action="store_true",
@@ -2278,10 +2294,17 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume(accelerator, args):
if args.resume:
print(f"resume training from state: {args.resume}")
if args.resume_from_huggingface:
def resume_from_local_or_hf_if_specified(accelerator, args):
if not args.resume:
return
if not args.resume_from_huggingface:
print(f"resume training from local state: {args.resume}")
accelerator.load_state(args.resume)
return
print(f"resume training from huggingface state: {args.resume}")
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
path_in_repo = "/".join(args.resume.split("/")[2:])
revision = None
@@ -2293,9 +2316,7 @@ def resume(accelerator, args):
repo_type = "model"
else:
path_in_repo, revision, repo_type = divided
print(
f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}"
)
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
list_files = huggingface_util.list_dir(
repo_id=repo_id,
@@ -2318,15 +2339,11 @@ def resume(accelerator, args):
return await asyncio.get_event_loop().run_in_executor(None, task)
loop = asyncio.get_event_loop()
results = loop.run_until_complete(
asyncio.gather(
*[download(filename=filename.rfilename) for filename in list_files]
)
)
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
if len(results) == 0:
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
dirname = os.path.dirname(results[0])
accelerator.load_state(dirname)
else:
accelerator.load_state(args.resume)
def get_optimizer(args, trainable_params):
@@ -2713,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
@@ -2883,6 +2900,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
@@ -2894,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading last state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
def save_sd_model_on_train_end(
args: argparse.Namespace,
src_path: str,
@@ -2932,13 +2961,6 @@ def save_sd_model_on_train_end(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120

View File

@@ -202,7 +202,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume(accelerator, args)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -310,7 +310,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume(accelerator, args)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -305,7 +305,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resumeする
train_util.resume(accelerator, args)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

View File

@@ -341,9 +341,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)