diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 40ba51df..43871c84 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -804,6 +804,9 @@ class WaveletLoss(nn.Module): # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: + combined_hf_pred = self._pad_tensors(combined_hf_pred) + combined_hf_target = self._pad_tensors(combined_hf_target) + combined_hf_pred = torch.cat(combined_hf_pred, dim=1) combined_hf_target = torch.cat(combined_hf_target, dim=1) else: