diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 85ee1dea..514f1a0a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -588,7 +588,7 @@ class WaveletLoss(torch.nn.Module): self.hl_weight2 = 0.01 self.hh_weight2 = 0.05 - assert pywt.wavedec2 is not None, "PyWavelet module not available. Please install `pip install PyWavelet`" + assert pywt.wavedec2 is not None, "PyWavelets module not available. Please install `pip install PyWavelets`" # Create GPU filters from wavelet wav = pywt.Wavelet(wavelet) self.register_buffer('dec_lo', torch.Tensor(wav.dec_lo).to(device))