diff --git a/library/train_util.py b/library/train_util.py index 6170782b..b9c4199b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1359,7 +1359,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) remove_epoch_no = None - last_n_epoch = args.save_last_n_epochs_model if args.save_last_n_epochs_state else args.save_last_n_epochs + last_n_epoch = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs if last_n_epoch is not None: remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epoch