From 121cc23f2eee736621ac7e039dea50d32fd88833 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 14 Jul 2025 21:26:36 -0400 Subject: [PATCH] Moved matplotlib into functions to not require it --- library/custom_train_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 549d4f7b..0c9b593a 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -13,7 +13,6 @@ from torch.types import Number from typing import List, Optional, Union, Protocol from .utils import setup_logging -import matplotlib.pyplot as plt try: import pywt @@ -1736,6 +1735,7 @@ def explore_wavelets(coeffs, coeffs_name="Coefficients"): # During training, visualize specific coefficients def visualize_training_wavelets(pred_coeffs, target_coeffs, step): """Call this during training to save wavelet visualizations""" + import matplotlib.pyplot as plt # 1. Visualize predicted coefficients for LH band, level 0 fig1 = visualize_wavelet_coefficients( @@ -1765,6 +1765,7 @@ def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0, """ Show all wavelet bands and levels in one overview plot """ + import matplotlib.pyplot as plt bands = ['lh', 'hl', 'hh'] n_levels = len(coeffs['lh']) # Assuming all bands have same levels @@ -1810,6 +1811,7 @@ def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level, """ Side-by-side comparison of predicted vs target coefficients """ + import matplotlib.pyplot as plt fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) @@ -1873,6 +1875,7 @@ def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0, Returns: fig: matplotlib figure object """ + import matplotlib.pyplot as plt # Extract the specific coefficients coeff_tensor = coeffs[band][level] # [batch, channel, h, w]