diff --git a/library/train_util.py b/library/train_util.py index 3a7c2c8a..aa65dc3c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1299,7 +1299,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs - remove_epoch_no = None if saving: os.makedirs(args.output_dir, exist_ok=True) save_func() @@ -1356,12 +1355,9 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - remove_epoch_no = None - last_n_epoch = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs - if last_n_epoch is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch - - if remove_epoch_no is not None: + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) if os.path.exists(state_dir_old): print(f"removing old state: {state_dir_old}") diff --git a/train_network.py b/train_network.py index 03fd01e7..b2c7b579 100644 --- a/train_network.py +++ b/train_network.py @@ -367,9 +367,9 @@ def train(args): print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) - saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) # end of epoch