|
|
|
|
@@ -7,11 +7,14 @@ import re
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import numpy as np
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
from torch.types import Number
|
|
|
|
|
from typing import List, Optional, Union, Protocol
|
|
|
|
|
from .utils import setup_logging
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
import pywt
|
|
|
|
|
except:
|
|
|
|
|
@@ -1064,7 +1067,7 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
energy_ratio: float = 0.0,
|
|
|
|
|
energy_scale_factor: float = 0.01,
|
|
|
|
|
normalize_bands: bool = True,
|
|
|
|
|
max_timestep: float = 1.0,
|
|
|
|
|
max_timestep: float = 1000,
|
|
|
|
|
timestep_intensity: float = 0.5,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
@@ -1156,13 +1159,10 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
band_weights = self.band_weights
|
|
|
|
|
band_level_weights = self.band_level_weights
|
|
|
|
|
|
|
|
|
|
# Apply timestep-based weighting if provided
|
|
|
|
|
# if timestep is not None:
|
|
|
|
|
# # Let users control intensity of timestep weighting (0.5 = moderate effect)
|
|
|
|
|
# intensity = getattr(self, "timestep_intensity", 0.5)
|
|
|
|
|
# current_band_weights, current_band_level_weights = self.noise_aware_weighting(
|
|
|
|
|
# timestep, self.max_timestep, intensity=intensity
|
|
|
|
|
# )
|
|
|
|
|
base_weight = torch.ones((batch_size), device=device)
|
|
|
|
|
if timestep is not None:
|
|
|
|
|
base_weight *= self.smooth_timestep_weight(timestep)
|
|
|
|
|
metrics['wavelet_loss/avg_timestep_adjusted_weight'] = base_weight.detach().mean().item()
|
|
|
|
|
|
|
|
|
|
# If negative it's from the end of the levels else it's the level.
|
|
|
|
|
ll_threshold = None
|
|
|
|
|
@@ -1180,6 +1180,8 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
weight_key = f"{band}{i+1}"
|
|
|
|
|
pred = pred_coeffs[band][i]
|
|
|
|
|
target = target_coeffs[band][i]
|
|
|
|
|
|
|
|
|
|
if band in pred_coeffs and band in target_coeffs:
|
|
|
|
|
if self.normalize_bands:
|
|
|
|
|
@@ -1187,9 +1189,34 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8)
|
|
|
|
|
target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8)
|
|
|
|
|
|
|
|
|
|
weight = band_level_weights.get(weight_key, band_weights[band])
|
|
|
|
|
band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i])
|
|
|
|
|
pattern_level_losses += band_loss.mean(dim=0) # mean stack dim
|
|
|
|
|
# 1. Magnitude loss
|
|
|
|
|
band_loss = self.loss_fn(pred, target)
|
|
|
|
|
|
|
|
|
|
# 2. Local structure loss
|
|
|
|
|
pred_grad_x = torch.diff(pred, dim=-1)
|
|
|
|
|
pred_grad_y = torch.diff(pred, dim=-2)
|
|
|
|
|
target_grad_x = torch.diff(target, dim=-1)
|
|
|
|
|
target_grad_y = torch.diff(target, dim=-2)
|
|
|
|
|
|
|
|
|
|
gradient_loss = F.mse_loss(pred_grad_x, target_grad_x) + \
|
|
|
|
|
F.mse_loss(pred_grad_y, target_grad_y)
|
|
|
|
|
|
|
|
|
|
# 3. Global correlation per channel
|
|
|
|
|
B, C = pred.shape[:2]
|
|
|
|
|
pred_flat = pred.view(B, C, -1)
|
|
|
|
|
target_flat = target.view(B, C, -1)
|
|
|
|
|
|
|
|
|
|
cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=2)
|
|
|
|
|
correlation_loss = (1 - cos_sim).mean()
|
|
|
|
|
|
|
|
|
|
weight = base_weight * band_level_weights.get(weight_key, band_weights[band])
|
|
|
|
|
pattern_level_losses += weight.view(-1, 1, 1, 1) * (band_loss +
|
|
|
|
|
0.05 * gradient_loss +
|
|
|
|
|
0.1 * correlation_loss) # mean stack dim
|
|
|
|
|
|
|
|
|
|
metrics[f"{band}{i}_band_loss"] = band_loss.detach().mean().item()
|
|
|
|
|
metrics[f"{band}{i}_gradient_loss"] = gradient_loss.detach().mean().item()
|
|
|
|
|
metrics[f"{band}{i}_correlation_loss"] = correlation_loss.detach().mean().item()
|
|
|
|
|
|
|
|
|
|
# Collect high frequency bands for visualization
|
|
|
|
|
combined_hf_pred.append(pred_coeffs[band][i])
|
|
|
|
|
@@ -1405,36 +1432,32 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
|
|
|
|
|
"""Calculate correlation metrics between prediction and target wavelet coefficients"""
|
|
|
|
|
metrics = {}
|
|
|
|
|
avg_correlations = []
|
|
|
|
|
|
|
|
|
|
for band in ["lh", "hl", "hh"]:
|
|
|
|
|
for i in range(1, self.level + 1):
|
|
|
|
|
# Get coefficients
|
|
|
|
|
pred = pred_coeffs[band][i - 1]
|
|
|
|
|
target = target_coeffs[band][i - 1]
|
|
|
|
|
band_correlations = []
|
|
|
|
|
for i in range(self.level):
|
|
|
|
|
pred = pred_coeffs[band][i] # [B, C, H, W]
|
|
|
|
|
target = target_coeffs[band][i]
|
|
|
|
|
|
|
|
|
|
# Flatten for batch-wise correlation
|
|
|
|
|
batch_size = pred.shape[0]
|
|
|
|
|
pred_flat = pred.view(batch_size, -1)
|
|
|
|
|
target_flat = target.view(batch_size, -1)
|
|
|
|
|
# Flatten spatial dims but keep batch/channel separate
|
|
|
|
|
pred_flat = pred.flatten(start_dim=2) # [B, C, H*W]
|
|
|
|
|
target_flat = target.flatten(start_dim=2)
|
|
|
|
|
|
|
|
|
|
# Center data
|
|
|
|
|
pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
|
|
|
|
|
target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
|
|
|
|
|
# Calculate correlation across spatial dimension
|
|
|
|
|
pred_centered = pred_flat - pred_flat.mean(dim=2, keepdim=True)
|
|
|
|
|
target_centered = target_flat - target_flat.mean(dim=2, keepdim=True)
|
|
|
|
|
|
|
|
|
|
# Calculate correlation
|
|
|
|
|
numerator = torch.sum(pred_centered * target_centered, dim=1)
|
|
|
|
|
denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
|
|
|
|
|
correlation = numerator / denominator
|
|
|
|
|
numerator = torch.sum(pred_centered * target_centered, dim=2)
|
|
|
|
|
denom = torch.sqrt(torch.sum(pred_centered**2, dim=2) *
|
|
|
|
|
torch.sum(target_centered**2, dim=2) + 1e-8)
|
|
|
|
|
|
|
|
|
|
# Average across batch
|
|
|
|
|
avg_correlation = correlation.mean().item()
|
|
|
|
|
metrics[f"{band}{i}_correlation"] = avg_correlation
|
|
|
|
|
avg_correlations.append(avg_correlation)
|
|
|
|
|
correlation = numerator / denom # [B, C]
|
|
|
|
|
avg_corr = correlation.mean().item()
|
|
|
|
|
|
|
|
|
|
# Calculate average correlation across all bands
|
|
|
|
|
if avg_correlations:
|
|
|
|
|
metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
|
|
|
|
|
metrics[f"{band}{i+1}_spatial_correlation"] = avg_corr
|
|
|
|
|
band_correlations.append(avg_corr)
|
|
|
|
|
|
|
|
|
|
metrics[f"{band}_avg_correlation"] = np.mean(band_correlations)
|
|
|
|
|
|
|
|
|
|
return metrics
|
|
|
|
|
|
|
|
|
|
@@ -1547,12 +1570,20 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
|
|
|
|
|
# Average sparsity across bands
|
|
|
|
|
if band_sparsities:
|
|
|
|
|
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
|
|
|
|
|
if band_non_zero_ratios: # Add this
|
|
|
|
|
metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
|
|
|
|
|
metrics["avg_sparsity_score"] = 1.0 / (sum(band_sparsities) / len(band_sparsities) + 1e-8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return metrics
|
|
|
|
|
|
|
|
|
|
def smooth_timestep_weight(self, timestep):
|
|
|
|
|
"""Smooth weight transition instead of hard cutoff"""
|
|
|
|
|
|
|
|
|
|
progress = 1.0 - (timestep / self.max_timestep)
|
|
|
|
|
|
|
|
|
|
weight = torch.sigmoid((progress - 0.3) * 10)
|
|
|
|
|
|
|
|
|
|
return weight
|
|
|
|
|
|
|
|
|
|
# TODO: does not work right in terms of weighting in an appropriate range
|
|
|
|
|
def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0):
|
|
|
|
|
"""
|
|
|
|
|
@@ -1680,6 +1711,244 @@ class WaveletLoss(nn.Module):
|
|
|
|
|
self.loss_fn = loss_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def explore_wavelets(coeffs, coeffs_name="Coefficients"):
|
|
|
|
|
"""Interactive exploration of wavelet coefficients"""
|
|
|
|
|
|
|
|
|
|
bands = list(coeffs.keys())
|
|
|
|
|
levels = list(range(len(coeffs[bands[0]])))
|
|
|
|
|
batch_size, n_channels = coeffs[bands[0]][0].shape[:2]
|
|
|
|
|
|
|
|
|
|
print(f"\n=== {coeffs_name} Structure ===")
|
|
|
|
|
print(f"Bands: {bands}")
|
|
|
|
|
print(f"Levels: {levels}")
|
|
|
|
|
print(f"Batch size: {batch_size}")
|
|
|
|
|
print(f"Channels: {n_channels}")
|
|
|
|
|
|
|
|
|
|
for band in bands:
|
|
|
|
|
for level in levels:
|
|
|
|
|
shape = coeffs[band][level].shape
|
|
|
|
|
sparsity = (torch.abs(coeffs[band][level]) < 0.01).float().mean().item()
|
|
|
|
|
magnitude = torch.abs(coeffs[band][level]).mean().item()
|
|
|
|
|
|
|
|
|
|
print(f"{band.upper()}{level+1}: shape={shape}, "
|
|
|
|
|
f"sparsity={sparsity:.1%}, avg_magnitude={magnitude:.4f}")
|
|
|
|
|
|
|
|
|
|
# During training, visualize specific coefficients
|
|
|
|
|
def visualize_training_wavelets(pred_coeffs, target_coeffs, step):
|
|
|
|
|
"""Call this during training to save wavelet visualizations"""
|
|
|
|
|
|
|
|
|
|
# 1. Visualize predicted coefficients for LH band, level 0
|
|
|
|
|
fig1 = visualize_wavelet_coefficients(
|
|
|
|
|
pred_coeffs, band='lh', level=0, batch_idx=0,
|
|
|
|
|
title_prefix="Predicted",
|
|
|
|
|
save_path=f"wavelets/pred_lh1_step_{step}.png"
|
|
|
|
|
)
|
|
|
|
|
plt.close(fig1)
|
|
|
|
|
|
|
|
|
|
# 2. Compare predicted vs target
|
|
|
|
|
fig2 = compare_wavelet_coefficients(
|
|
|
|
|
pred_coeffs, target_coeffs, band='hl', level=1,
|
|
|
|
|
batch_idx=0, channel_idx=0,
|
|
|
|
|
save_path=f"wavelets/comparison_hl2_step_{step}.png"
|
|
|
|
|
)
|
|
|
|
|
plt.close(fig2)
|
|
|
|
|
|
|
|
|
|
# 3. Overview of all bands
|
|
|
|
|
fig3 = visualize_all_bands_levels(
|
|
|
|
|
pred_coeffs, title_prefix="Predicted", batch_idx=0, channel_idx=0,
|
|
|
|
|
save_path=f"wavelets/overview_step_{step}.png"
|
|
|
|
|
)
|
|
|
|
|
plt.close(fig3)
|
|
|
|
|
|
|
|
|
|
def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0,
|
|
|
|
|
channel_idx=0, save_path=None):
|
|
|
|
|
"""
|
|
|
|
|
Show all wavelet bands and levels in one overview plot
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
bands = ['lh', 'hl', 'hh']
|
|
|
|
|
n_levels = len(coeffs['lh']) # Assuming all bands have same levels
|
|
|
|
|
|
|
|
|
|
fig, axes = plt.subplots(len(bands), n_levels, figsize=(4*n_levels, 3*len(bands)))
|
|
|
|
|
|
|
|
|
|
if n_levels == 1:
|
|
|
|
|
axes = axes.reshape(-1, 1)
|
|
|
|
|
|
|
|
|
|
for band_idx, band in enumerate(bands):
|
|
|
|
|
for level in range(n_levels):
|
|
|
|
|
ax = axes[band_idx, level]
|
|
|
|
|
|
|
|
|
|
# Get coefficient data
|
|
|
|
|
coeff_data = coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
# Plot
|
|
|
|
|
im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
|
|
|
|
|
ax.set_title(f'{band.upper()}{level+1}')
|
|
|
|
|
|
|
|
|
|
# Add colorbar for better interpretation
|
|
|
|
|
plt.colorbar(im, ax=ax, shrink=0.6)
|
|
|
|
|
|
|
|
|
|
# Add sparsity info
|
|
|
|
|
sparsity = (np.abs(coeff_data) < 0.01).mean()
|
|
|
|
|
ax.text(0.02, 0.02, f'Sparse: {sparsity:.1%}',
|
|
|
|
|
transform=ax.transAxes, bbox=dict(boxstyle='round',
|
|
|
|
|
facecolor='white', alpha=0.8), fontsize=8)
|
|
|
|
|
|
|
|
|
|
fig.suptitle(f'{title_prefix} All Wavelet Bands - Sample {batch_idx}, Channel {channel_idx}',
|
|
|
|
|
fontsize=14)
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
if save_path:
|
|
|
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level,
|
|
|
|
|
batch_idx=0, channel_idx=0, save_path=None):
|
|
|
|
|
"""
|
|
|
|
|
Side-by-side comparison of predicted vs target coefficients
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
|
|
|
|
|
|
|
|
|
|
# Get data
|
|
|
|
|
pred_data = pred_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
|
|
|
|
target_data = target_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
# Calculate difference
|
|
|
|
|
diff_data = pred_data - target_data
|
|
|
|
|
|
|
|
|
|
# Determine common color scale
|
|
|
|
|
vmin = min(pred_data.min(), target_data.min())
|
|
|
|
|
vmax = max(pred_data.max(), target_data.max())
|
|
|
|
|
|
|
|
|
|
# Plot predicted
|
|
|
|
|
im1 = ax1.imshow(pred_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
|
|
|
|
ax1.set_title(f'Predicted {band.upper()}{level+1} Ch{channel_idx}')
|
|
|
|
|
plt.colorbar(im1, ax=ax1, shrink=0.8)
|
|
|
|
|
|
|
|
|
|
# Plot target
|
|
|
|
|
im2 = ax2.imshow(target_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
|
|
|
|
ax2.set_title(f'Target {band.upper()}{level+1} Ch{channel_idx}')
|
|
|
|
|
plt.colorbar(im2, ax=ax2, shrink=0.8)
|
|
|
|
|
|
|
|
|
|
# Plot difference
|
|
|
|
|
im3 = ax3.imshow(diff_data, cmap='RdBu_r', vmin=-np.abs(diff_data).max(),
|
|
|
|
|
vmax=np.abs(diff_data).max())
|
|
|
|
|
ax3.set_title('Difference (Pred - Target)')
|
|
|
|
|
plt.colorbar(im3, ax=ax3, shrink=0.8)
|
|
|
|
|
|
|
|
|
|
# Add correlation info
|
|
|
|
|
correlation = np.corrcoef(pred_data.flatten(), target_data.flatten())[0,1]
|
|
|
|
|
mse = np.mean((pred_data - target_data)**2)
|
|
|
|
|
|
|
|
|
|
fig.suptitle(f'Wavelet Comparison - Correlation: {correlation:.3f}, MSE: {mse:.6f}',
|
|
|
|
|
fontsize=14)
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
if save_path:
|
|
|
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
return fig
|
|
|
|
|
|
|
|
|
|
def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0,
|
|
|
|
|
channel_idx=None, title_prefix="",
|
|
|
|
|
save_path=None, figsize=(15, 10)):
|
|
|
|
|
"""
|
|
|
|
|
Visualize wavelet coefficients for a specific band and level
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
coeffs: dict with structure coeffs[band][level] -> [batch, channel, h, w]
|
|
|
|
|
band: str, one of ['lh', 'hl', 'hh']
|
|
|
|
|
level: int, wavelet decomposition level (0-indexed)
|
|
|
|
|
batch_idx: int, which sample in batch to visualize
|
|
|
|
|
channel_idx: int or None, specific channel to show (None = all channels)
|
|
|
|
|
title_prefix: str, prefix for plot titles (e.g., "Predicted" or "Target")
|
|
|
|
|
save_path: str or None, path to save the plot
|
|
|
|
|
figsize: tuple, figure size
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
fig: matplotlib figure object
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# Extract the specific coefficients
|
|
|
|
|
coeff_tensor = coeffs[band][level] # [batch, channel, h, w]
|
|
|
|
|
|
|
|
|
|
# Get single sample
|
|
|
|
|
sample_coeffs = coeff_tensor[batch_idx] # [channel, h, w]
|
|
|
|
|
|
|
|
|
|
batch_size, num_channels, height, width = coeff_tensor.shape
|
|
|
|
|
|
|
|
|
|
# Determine which channels to visualize
|
|
|
|
|
if channel_idx is not None:
|
|
|
|
|
channels_to_show = [channel_idx]
|
|
|
|
|
sample_coeffs = sample_coeffs[channel_idx:channel_idx+1]
|
|
|
|
|
else:
|
|
|
|
|
channels_to_show = list(range(num_channels))
|
|
|
|
|
|
|
|
|
|
# Create subplot layout
|
|
|
|
|
n_channels = len(channels_to_show)
|
|
|
|
|
cols = min(4, n_channels) # Max 4 columns
|
|
|
|
|
rows = (n_channels + cols - 1) // cols # Ceiling division
|
|
|
|
|
|
|
|
|
|
fig, axes = plt.subplots(rows, cols, figsize=figsize)
|
|
|
|
|
|
|
|
|
|
# Handle single subplot case
|
|
|
|
|
if n_channels == 1:
|
|
|
|
|
axes = [axes]
|
|
|
|
|
elif rows == 1:
|
|
|
|
|
axes = [axes] if n_channels == 1 else axes
|
|
|
|
|
else:
|
|
|
|
|
axes = axes.flatten()
|
|
|
|
|
|
|
|
|
|
# Plot each channel
|
|
|
|
|
for i, ch_idx in enumerate(channels_to_show):
|
|
|
|
|
if i >= len(axes):
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
ax = axes[i]
|
|
|
|
|
|
|
|
|
|
# Get coefficient data for this channel
|
|
|
|
|
coeff_data = sample_coeffs[i].detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
# Create visualization
|
|
|
|
|
im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
|
|
|
|
|
|
|
|
|
|
# Add colorbar
|
|
|
|
|
plt.colorbar(im, ax=ax, shrink=0.8)
|
|
|
|
|
|
|
|
|
|
# Set title
|
|
|
|
|
ax.set_title(f'{title_prefix} {band.upper()}{level+1} Ch{ch_idx}\n'
|
|
|
|
|
f'Range: [{coeff_data.min():.3f}, {coeff_data.max():.3f}]')
|
|
|
|
|
|
|
|
|
|
# Add statistics text
|
|
|
|
|
stats_text = f'Mean: {coeff_data.mean():.3f}\n' \
|
|
|
|
|
f'Std: {coeff_data.std():.3f}\n' \
|
|
|
|
|
f'Non-zero: {(np.abs(coeff_data) > 0.01).mean():.1%}'
|
|
|
|
|
|
|
|
|
|
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
|
|
|
|
|
verticalalignment='top', bbox=dict(boxstyle='round',
|
|
|
|
|
facecolor='white', alpha=0.8), fontsize=8)
|
|
|
|
|
|
|
|
|
|
# Hide unused subplots
|
|
|
|
|
for i in range(n_channels, len(axes)):
|
|
|
|
|
axes[i].axis('off')
|
|
|
|
|
|
|
|
|
|
# Add main title
|
|
|
|
|
fig.suptitle(f'{title_prefix} Wavelet Coefficients - {band.upper()} Level {level+1}\n'
|
|
|
|
|
f'Sample {batch_idx}, Shape: {coeff_tensor.shape}',
|
|
|
|
|
fontsize=14, fontweight='bold')
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
if save_path:
|
|
|
|
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
return fig
|
|
|
|
|
|
|
|
|
|
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
|
|
|
|
|
"""
|
|
|
|
|
Visualize QWT decomposition of input, prediction, and target.
|
|
|
|
|
|