diff --git a/train_network.py b/train_network.py index 4707d5ae..3db583f1 100644 --- a/train_network.py +++ b/train_network.py @@ -935,7 +935,7 @@ class NetworkTrainer: accelerator.end_training() - if is_main_process and args.save_state or args.save_state_on_train_end: + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) if is_main_process: