Resume training from a Hugging Face repository

This commit is contained in:
ddPn08
2023-03-30 23:36:42 +09:00
parent a7d302e196
commit b5ff4e816f
8 changed files with 139 additions and 77 deletions

View File

@@ -24,7 +24,7 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.utils as utils
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
@@ -285,9 +285,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -628,7 +626,7 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -668,7 +666,7 @@ def train(args):
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
utils.huggingface_upload(ckpt_file, args, "/" + ckpt_name)
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.")