Save state when args.save_last_n_epochs_state is designated

This commit is contained in:
Yuta Hayashibe
2023-01-15 19:43:37 +09:00
parent a888223869
commit c6e28faa57

View File

@@ -1350,7 +1350,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
remove_old_func = remove_du
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
if saving and args.save_state:
if saving and args.save_state or args.save_last_n_epochs_state is not None:
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)