mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add save_state_on_train end, fix reg imgs repeats
This commit is contained in:
@@ -302,22 +302,19 @@ def train(args):
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1)))
|
||||
if remove_epoch_no is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
shutil.rmtree(state_dir_old)
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user