diff --git a/train_network.py b/train_network.py index 2eabcdf3..18343ce3 100644 --- a/train_network.py +++ b/train_network.py @@ -45,7 +45,11 @@ from library.custom_train_functions import ( apply_masked_loss, ) from library.utils import setup_logging, add_logging_arguments -from wavelet_loss import WaveletLoss + +try: + from wavelet_loss import WaveletLoss +except: + raise ImportError("wavelet-loss is not installed. Install it with `pip install git+https://github.com/rockerBOO/wavelet-loss`") setup_logging() import logging