mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Moved matplotlib into functions to not require it
This commit is contained in:
@@ -13,7 +13,6 @@ from torch.types import Number
|
||||
from typing import List, Optional, Union, Protocol
|
||||
from .utils import setup_logging
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import pywt
|
||||
@@ -1736,6 +1735,7 @@ def explore_wavelets(coeffs, coeffs_name="Coefficients"):
|
||||
# During training, visualize specific coefficients
|
||||
def visualize_training_wavelets(pred_coeffs, target_coeffs, step):
|
||||
"""Call this during training to save wavelet visualizations"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 1. Visualize predicted coefficients for LH band, level 0
|
||||
fig1 = visualize_wavelet_coefficients(
|
||||
@@ -1765,6 +1765,7 @@ def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0,
|
||||
"""
|
||||
Show all wavelet bands and levels in one overview plot
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
bands = ['lh', 'hl', 'hh']
|
||||
n_levels = len(coeffs['lh']) # Assuming all bands have same levels
|
||||
@@ -1810,6 +1811,7 @@ def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level,
|
||||
"""
|
||||
Side-by-side comparison of predicted vs target coefficients
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
|
||||
|
||||
@@ -1873,6 +1875,7 @@ def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0,
|
||||
Returns:
|
||||
fig: matplotlib figure object
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Extract the specific coefficients
|
||||
coeff_tensor = coeffs[band][level] # [batch, channel, h, w]
|
||||
|
||||
Reference in New Issue
Block a user