mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
change async uploading to optional
This commit is contained in:
@@ -20,11 +20,11 @@ def exists_repo(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@fire_in_thread
|
|
||||||
def upload(
|
def upload(
|
||||||
src: Union[str, Path, bytes, BinaryIO],
|
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
|
src: Union[str, Path, bytes, BinaryIO],
|
||||||
dest_suffix: str = "",
|
dest_suffix: str = "",
|
||||||
|
force_sync_upload: bool = False,
|
||||||
):
|
):
|
||||||
repo_id = args.huggingface_repo_id
|
repo_id = args.huggingface_repo_id
|
||||||
repo_type = args.huggingface_repo_type
|
repo_type = args.huggingface_repo_type
|
||||||
@@ -38,20 +38,27 @@ def upload(
|
|||||||
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||||
isinstance(src, Path) and src.is_dir()
|
isinstance(src, Path) and src.is_dir()
|
||||||
)
|
)
|
||||||
if is_folder:
|
|
||||||
api.upload_folder(
|
def uploader():
|
||||||
repo_id=repo_id,
|
if is_folder:
|
||||||
repo_type=repo_type,
|
api.upload_folder(
|
||||||
folder_path=src,
|
repo_id=repo_id,
|
||||||
path_in_repo=path_in_repo,
|
repo_type=repo_type,
|
||||||
)
|
folder_path=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
api.upload_file(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
path_or_fileobj=src,
|
||||||
|
path_in_repo=path_in_repo,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.async_upload and not force_sync_upload:
|
||||||
|
fire_in_thread(uploader)
|
||||||
else:
|
else:
|
||||||
api.upload_file(
|
uploader()
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type=repo_type,
|
|
||||||
path_or_fileobj=src,
|
|
||||||
path_in_repo=path_in_repo,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def list_dir(
|
def list_dir(
|
||||||
|
|||||||
@@ -1909,6 +1909,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--async_upload",
|
||||||
|
action="store_true",
|
||||||
|
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_precision",
|
"--save_precision",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -2831,7 +2836,7 @@ def save_sd_model_on_epoch_end(
|
|||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_sd(old_epoch_no):
|
def remove_sd(old_epoch_no):
|
||||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||||
@@ -2852,7 +2857,7 @@ def save_sd_model_on_epoch_end(
|
|||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(out_dir, args, "/" + model_name)
|
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||||
|
|
||||||
def remove_du(old_epoch_no):
|
def remove_du(old_epoch_no):
|
||||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||||
@@ -2873,7 +2878,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
|
|||||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
accelerator.save_state(state_dir)
|
accelerator.save_state(state_dir)
|
||||||
if args.save_state_to_huggingface:
|
if args.save_state_to_huggingface:
|
||||||
huggingface_util.upload(state_dir, args, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||||
|
|
||||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||||
if last_n_epochs is not None:
|
if last_n_epochs is not None:
|
||||||
@@ -2909,7 +2914,7 @@ def save_sd_model_on_train_end(
|
|||||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||||
)
|
)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, model_name)
|
out_dir = os.path.join(args.output_dir, model_name)
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
@@ -2919,7 +2924,7 @@ def save_sd_model_on_train_end(
|
|||||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||||
)
|
)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(out_dir, args, "/" + model_name)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
|
||||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||||
|
|||||||
@@ -627,7 +627,7 @@ def train(args):
|
|||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -668,7 +668,7 @@ def train(args):
|
|||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -452,7 +452,7 @@ def train(args):
|
|||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -494,7 +494,7 @@ def train(args):
|
|||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -495,7 +495,7 @@ def train(args):
|
|||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
|
||||||
@@ -538,7 +538,7 @@ def train(args):
|
|||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user