diff --git a/train_network.py b/train_network.py index 32d6640d..2eabcdf3 100644 --- a/train_network.py +++ b/train_network.py @@ -43,9 +43,9 @@ from library.custom_train_functions import ( add_v_prediction_like_loss, apply_debiased_estimation, apply_masked_loss, - WaveletLoss ) from library.utils import setup_logging, add_logging_arguments +from wavelet_loss import WaveletLoss setup_logging() import logging