Moved matplotlib into functions to not require it

This commit is contained in:
rockerBOO
2025-07-14 21:26:36 -04:00
parent 8cc81e45f7
commit 121cc23f2e

View File

@@ -13,7 +13,6 @@ from torch.types import Number
from typing import List, Optional, Union, Protocol
from .utils import setup_logging
import matplotlib.pyplot as plt
try:
import pywt
@@ -1736,6 +1735,7 @@ def explore_wavelets(coeffs, coeffs_name="Coefficients"):
# During training, visualize specific coefficients
def visualize_training_wavelets(pred_coeffs, target_coeffs, step):
"""Call this during training to save wavelet visualizations"""
import matplotlib.pyplot as plt
# 1. Visualize predicted coefficients for LH band, level 0
fig1 = visualize_wavelet_coefficients(
@@ -1765,6 +1765,7 @@ def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0,
"""
Show all wavelet bands and levels in one overview plot
"""
import matplotlib.pyplot as plt
bands = ['lh', 'hl', 'hh']
n_levels = len(coeffs['lh']) # Assuming all bands have same levels
@@ -1810,6 +1811,7 @@ def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level,
"""
Side-by-side comparison of predicted vs target coefficients
"""
import matplotlib.pyplot as plt
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
@@ -1873,6 +1875,7 @@ def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0,
Returns:
fig: matplotlib figure object
"""
import matplotlib.pyplot as plt
# Extract the specific coefficients
coeff_tensor = coeffs[band][level] # [batch, channel, h, w]