diff --git a/library/train_util.py b/library/train_util.py index 0bd87bc8..5d68c6ac 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3167,10 +3167,11 @@ def save_sd_model_on_epoch_end_or_stepwise( print(f"removing old model: {remove_out_dir}") shutil.rmtree(remove_out_dir) - if on_epoch_end: - save_and_remove_state_on_epoch_end(args, accelerator, epoch_no) - else: - save_and_remove_state_stepwise(args, accelerator, global_step) + if args.save_state: + if on_epoch_end: + save_and_remove_state_on_epoch_end(args, accelerator, epoch_no) + else: + save_and_remove_state_stepwise(args, accelerator, global_step) def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):