diff --git a/train_network.py b/train_network.py index 7563424b..54949f21 100644 --- a/train_network.py +++ b/train_network.py @@ -39,6 +39,7 @@ from library.custom_train_functions import ( apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments +from accelerate.utils import gather_object setup_logging() import logging