add save_last_n_epochs_state to train_network

This commit is contained in:
Kohya S
2023-01-19 20:59:45 +09:00
parent 8bd844cdc1
commit 758323532b
2 changed files with 5 additions and 9 deletions

View File

@@ -367,9 +367,9 @@ def train(args):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# end of epoch