From 8cc81e45f708cd6d9b65f27e150f79be40153204 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 14 Jul 2025 21:20:49 -0400
Subject: [PATCH] Fix wavelet loss on non-flow-matching models (SD1.5, SDXL).
 Fix wavelet correlation.

---
 flux_train_network.py             |   1 +
 library/custom_train_functions.py | 357 ++++++++++++++++++++++++++----
 library/utils.py                  |   3 +-
 sd3_train_network.py              |   1 +
 train_network.py                  |  25 ++-
 5 files changed, 337 insertions(+), 50 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index 824c4537..27be4bde 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         self.sample_prompts_te_outputs = None
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
+        self.is_flow_matching = True
 
     def assert_extra_args(
         self,
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index fa0ad14d..549d4f7b 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -7,11 +7,14 @@ import re
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 from torch import Tensor
 from torch.types import Number
 from typing import List, Optional, Union, Protocol
 from .utils import setup_logging
 
+import matplotlib.pyplot as plt
+
 try:
     import pywt
 except:
@@ -1064,7 +1067,7 @@ class WaveletLoss(nn.Module):
         energy_ratio: float = 0.0,
         energy_scale_factor: float = 0.01,
         normalize_bands: bool = True,
-        max_timestep: float = 1.0,
+        max_timestep: float = 1000.0,
         timestep_intensity: float = 0.5,
     ):
         """
@@ -1156,13 +1159,10 @@ class WaveletLoss(nn.Module):
         band_weights = self.band_weights
         band_level_weights = self.band_level_weights
 
-        # Apply timestep-based weighting if provided
-        # if timestep is not None:
-        #     # Let users control intensity of timestep weighting (0.5 = moderate effect)
-        #     intensity = getattr(self, "timestep_intensity", 0.5)
-        #     current_band_weights, current_band_level_weights = self.noise_aware_weighting(
-        #         timestep, self.max_timestep, intensity=intensity
-        #     )
+        base_weight = torch.ones((batch_size,), device=device)
+        if timestep is not None:
+            base_weight *= self.smooth_timestep_weight(timestep)
+            metrics["wavelet_loss/avg_timestep_adjusted_weight"] = base_weight.detach().mean().item()
 
         # If negative it's from the end of the levels else it's the level.
         ll_threshold = None
@@ -1180,6 +1180,8 @@
                     continue
 
                 weight_key = f"{band}{i+1}"
 
                 if band in pred_coeffs and band in target_coeffs:
                     if self.normalize_bands:
                         # normalization
                         pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8)
                         target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8)
 
-                    weight = band_level_weights.get(weight_key, band_weights[band])
-                    band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i])
-                    pattern_level_losses += band_loss.mean(dim=0)  # mean stack dim
+                    # Bind after normalization so the losses below see the normalized coefficients
+                    pred = pred_coeffs[band][i]
+                    target = target_coeffs[band][i]
+
+                    # 1. Magnitude loss
+                    band_loss = self.loss_fn(pred, target)
+
+                    # 2. Local structure loss on horizontal/vertical finite differences
+                    pred_grad_x = torch.diff(pred, dim=-1)
+                    pred_grad_y = torch.diff(pred, dim=-2)
+                    target_grad_x = torch.diff(target, dim=-1)
+                    target_grad_y = torch.diff(target, dim=-2)
+
+                    gradient_loss = F.mse_loss(pred_grad_x, target_grad_x) + \
+                                    F.mse_loss(pred_grad_y, target_grad_y)
+
+                    # 3. Global correlation per channel
+                    B, C = pred.shape[:2]
+                    pred_flat = pred.view(B, C, -1)
+                    target_flat = target.view(B, C, -1)
+
+                    cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=2)
+                    correlation_loss = (1 - cos_sim).mean()
+
+                    weight = base_weight * band_level_weights.get(weight_key, band_weights[band])
+                    pattern_level_losses += weight.view(-1, 1, 1, 1) * (band_loss +
+                                                                        0.05 * gradient_loss +
+                                                                        0.1 * correlation_loss)
+
+                    metrics[f"{band}{i+1}_band_loss"] = band_loss.detach().mean().item()
+                    metrics[f"{band}{i+1}_gradient_loss"] = gradient_loss.detach().mean().item()
+                    metrics[f"{band}{i+1}_correlation_loss"] = correlation_loss.detach().mean().item()
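For reference, the three per-band terms combine as in the following standalone sketch (illustrative only, with dummy tensors; it assumes `loss_fn` is element-wise MSE with `reduction="none"`, so the scalar gradient and correlation terms broadcast against the per-element magnitude term):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 4, 16, 16
pred = torch.randn(B, C, H, W)
target = torch.randn(B, C, H, W)

# 1. Magnitude: element-wise error on the raw coefficients
band_loss = F.mse_loss(pred, target, reduction="none")  # [B, C, H, W]

# 2. Local structure: error on horizontal/vertical finite differences (scalar)
gradient_loss = F.mse_loss(torch.diff(pred, dim=-1), torch.diff(target, dim=-1)) \
              + F.mse_loss(torch.diff(pred, dim=-2), torch.diff(target, dim=-2))

# 3. Global correlation: cosine similarity over flattened spatial dims, per channel
cos_sim = F.cosine_similarity(pred.view(B, C, -1), target.view(B, C, -1), dim=2)  # [B, C]
correlation_loss = (1 - cos_sim).mean()  # scalar

# Scalars broadcast against the [B, C, H, W] magnitude term
per_band = band_loss + 0.05 * gradient_loss + 0.1 * correlation_loss
print(per_band.shape, float(gradient_loss), float(correlation_loss))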
 
                     # Collect high frequency bands for visualization
                     combined_hf_pred.append(pred_coeffs[band][i])
@@ -1405,37 +1432,33 @@ class WaveletLoss(nn.Module):
     def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
         """Calculate correlation metrics between prediction and target wavelet coefficients"""
         metrics = {}
-        avg_correlations = []
-
+
         for band in ["lh", "hl", "hh"]:
-            for i in range(1, self.level + 1):
-                # Get coefficients
-                pred = pred_coeffs[band][i - 1]
-                target = target_coeffs[band][i - 1]
-
-                # Flatten for batch-wise correlation
-                batch_size = pred.shape[0]
-                pred_flat = pred.view(batch_size, -1)
-                target_flat = target.view(batch_size, -1)
-
-                # Center data
-                pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
-                target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
-
-                # Calculate correlation
-                numerator = torch.sum(pred_centered * target_centered, dim=1)
-                denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
-                correlation = numerator / denominator
-
-                # Average across batch
-                avg_correlation = correlation.mean().item()
-                metrics[f"{band}{i}_correlation"] = avg_correlation
-                avg_correlations.append(avg_correlation)
-
-        # Calculate average correlation across all bands
-        if avg_correlations:
-            metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
-
+            band_correlations = []
+            for i in range(self.level):
+                pred = pred_coeffs[band][i]  # [B, C, H, W]
+                target = target_coeffs[band][i]
+
+                # Flatten spatial dims but keep batch/channel separate
+                pred_flat = pred.flatten(start_dim=2)  # [B, C, H*W]
+                target_flat = target.flatten(start_dim=2)
+
+                # Calculate correlation across the spatial dimension
+                pred_centered = pred_flat - pred_flat.mean(dim=2, keepdim=True)
+                target_centered = target_flat - target_flat.mean(dim=2, keepdim=True)
+
+                numerator = torch.sum(pred_centered * target_centered, dim=2)
+                denom = torch.sqrt(torch.sum(pred_centered**2, dim=2) *
+                                   torch.sum(target_centered**2, dim=2) + 1e-8)
+
+                correlation = numerator / denom  # [B, C]
+                avg_corr = correlation.mean().item()
+
+                metrics[f"{band}{i+1}_spatial_correlation"] = avg_corr
+                band_correlations.append(avg_corr)
+
+            metrics[f"{band}_avg_correlation"] = np.mean(band_correlations)
+
         return metrics
 
     @torch.no_grad()
@@ -1547,12 +1570,20 @@ class WaveletLoss(nn.Module):
 
         # Average sparsity across bands
         if band_sparsities:
-            metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
-        if band_non_zero_ratios:  # Add this
-            metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
+            metrics["avg_sparsity_score"] = 1.0 / (sum(band_sparsities) / len(band_sparsities) + 1e-8)
 
         return metrics
 
+    def smooth_timestep_weight(self, timestep):
+        """Smooth weight transition instead of a hard cutoff."""
+
+        progress = 1.0 - (timestep / self.max_timestep)
+
+        weight = torch.sigmoid((progress - 0.3) * 10)
+
+        return weight
+
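A quick numeric check of this schedule (assuming the new `max_timestep = 1000` default): the weight stays near 1 for low-noise timesteps and falls off smoothly around `t = 700`:

import torch

max_timestep = 1000.0
t = torch.tensor([0.0, 250.0, 500.0, 700.0, 1000.0])
progress = 1.0 - t / max_timestep
weight = torch.sigmoid((progress - 0.3) * 10)
print(weight)
# approx [0.999, 0.989, 0.881, 0.500, 0.047]: late (noisy) timesteps are down-weighted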
     # TODO: does not work right in terms of weighting in an appropriate range
     def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0):
         """
@@ -1680,6 +1711,244 @@ class WaveletLoss(nn.Module):
         self.loss_fn = loss_fn
 
 
+def explore_wavelets(coeffs, coeffs_name="Coefficients"):
+    """Print a structural summary (shapes, sparsity, magnitudes) of wavelet coefficients"""
+
+    bands = list(coeffs.keys())
+    levels = list(range(len(coeffs[bands[0]])))
+    batch_size, n_channels = coeffs[bands[0]][0].shape[:2]
+
+    print(f"\n=== {coeffs_name} Structure ===")
+    print(f"Bands: {bands}")
+    print(f"Levels: {levels}")
+    print(f"Batch size: {batch_size}")
+    print(f"Channels: {n_channels}")
+
+    for band in bands:
+        for level in levels:
+            shape = coeffs[band][level].shape
+            sparsity = (torch.abs(coeffs[band][level]) < 0.01).float().mean().item()
+            magnitude = torch.abs(coeffs[band][level]).mean().item()
+
+            print(f"{band.upper()}{level+1}: shape={shape}, "
+                  f"sparsity={sparsity:.1%}, avg_magnitude={magnitude:.4f}")
+
+
+# During training, visualize specific coefficients
+def visualize_training_wavelets(pred_coeffs, target_coeffs, step):
+    """Call this during training to save wavelet visualizations"""
+
+    # 1. Visualize predicted coefficients for the LH band, level 0
+    fig1 = visualize_wavelet_coefficients(
+        pred_coeffs, band='lh', level=0, batch_idx=0,
+        title_prefix="Predicted",
+        save_path=f"wavelets/pred_lh1_step_{step}.png"
+    )
+    plt.close(fig1)
+
+    # 2. Compare predicted vs target
+    fig2 = compare_wavelet_coefficients(
+        pred_coeffs, target_coeffs, band='hl', level=1,
+        batch_idx=0, channel_idx=0,
+        save_path=f"wavelets/comparison_hl2_step_{step}.png"
+    )
+    plt.close(fig2)
+
+    # 3. Overview of all bands
+    fig3 = visualize_all_bands_levels(
+        pred_coeffs, title_prefix="Predicted", batch_idx=0, channel_idx=0,
+        save_path=f"wavelets/overview_step_{step}.png"
+    )
+    plt.close(fig3)
+
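All of the helpers below assume the coefficient layout `coeffs[band][level] -> [batch, channel, h, w]`. A minimal sketch with random tensors, purely to illustrate that structure:

import torch

B, C, size, n_levels = 2, 4, 32, 2
# Dummy dict in the expected layout; each level halves the spatial resolution
dummy_coeffs = {
    band: [torch.randn(B, C, size // 2 ** (lvl + 1), size // 2 ** (lvl + 1))
           for lvl in range(n_levels)]
    for band in ("lh", "hl", "hh")
}

explore_wavelets(dummy_coeffs, coeffs_name="Dummy")
# e.g. "LH1: shape=torch.Size([2, 4, 16, 16]), sparsity=0.8%, avg_magnitude=0.7969"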
+def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0,
+                               channel_idx=0, save_path=None):
+    """
+    Show all wavelet bands and levels in one overview plot
+    """
+
+    bands = ['lh', 'hl', 'hh']
+    n_levels = len(coeffs['lh'])  # Assumes all bands have the same number of levels
+
+    fig, axes = plt.subplots(len(bands), n_levels, figsize=(4*n_levels, 3*len(bands)))
+
+    if n_levels == 1:
+        axes = axes.reshape(-1, 1)
+
+    for band_idx, band in enumerate(bands):
+        for level in range(n_levels):
+            ax = axes[band_idx, level]
+
+            # Get coefficient data
+            coeff_data = coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
+
+            # Plot
+            im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
+            ax.set_title(f'{band.upper()}{level+1}')
+
+            # Add colorbar for better interpretation
+            plt.colorbar(im, ax=ax, shrink=0.6)
+
+            # Add sparsity info
+            sparsity = (np.abs(coeff_data) < 0.01).mean()
+            ax.text(0.02, 0.02, f'Sparse: {sparsity:.1%}',
+                    transform=ax.transAxes, bbox=dict(boxstyle='round',
+                    facecolor='white', alpha=0.8), fontsize=8)
+
+    fig.suptitle(f'{title_prefix} All Wavelet Bands - Sample {batch_idx}, Channel {channel_idx}',
+                 fontsize=14)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=150, bbox_inches='tight')
+
+    return fig
+
+
+def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level,
+                                 batch_idx=0, channel_idx=0, save_path=None):
+    """
+    Side-by-side comparison of predicted vs target coefficients
+    """
+
+    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
+
+    # Get data
+    pred_data = pred_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
+    target_data = target_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
+
+    # Calculate difference
+    diff_data = pred_data - target_data
+
+    # Determine common color scale
+    vmin = min(pred_data.min(), target_data.min())
+    vmax = max(pred_data.max(), target_data.max())
+
+    # Plot predicted
+    im1 = ax1.imshow(pred_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
+    ax1.set_title(f'Predicted {band.upper()}{level+1} Ch{channel_idx}')
+    plt.colorbar(im1, ax=ax1, shrink=0.8)
+
+    # Plot target
+    im2 = ax2.imshow(target_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
+    ax2.set_title(f'Target {band.upper()}{level+1} Ch{channel_idx}')
+    plt.colorbar(im2, ax=ax2, shrink=0.8)
+
+    # Plot difference
+    im3 = ax3.imshow(diff_data, cmap='RdBu_r', vmin=-np.abs(diff_data).max(),
+                     vmax=np.abs(diff_data).max())
+    ax3.set_title('Difference (Pred - Target)')
+    plt.colorbar(im3, ax=ax3, shrink=0.8)
+
+    # Add correlation info
+    correlation = np.corrcoef(pred_data.flatten(), target_data.flatten())[0, 1]
+    mse = np.mean((pred_data - target_data)**2)
+
+    fig.suptitle(f'Wavelet Comparison - Correlation: {correlation:.3f}, MSE: {mse:.6f}',
+                 fontsize=14)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=150, bbox_inches='tight')
+
+    return fig
+
+def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0,
+                                   channel_idx=None, title_prefix="",
+                                   save_path=None, figsize=(15, 10)):
+    """
+    Visualize wavelet coefficients for a specific band and level
+
+    Args:
+        coeffs: dict with structure coeffs[band][level] -> [batch, channel, h, w]
+        band: str, one of ['lh', 'hl', 'hh']
+        level: int, wavelet decomposition level (0-indexed)
+        batch_idx: int, which sample in the batch to visualize
+        channel_idx: int or None, specific channel to show (None = all channels)
+        title_prefix: str, prefix for plot titles (e.g., "Predicted" or "Target")
+        save_path: str or None, path to save the plot
+        figsize: tuple, figure size
+
+    Returns:
+        fig: matplotlib figure object
+    """
+
+    # Extract the specific coefficients
+    coeff_tensor = coeffs[band][level]  # [batch, channel, h, w]
+
+    # Get single sample
+    sample_coeffs = coeff_tensor[batch_idx]  # [channel, h, w]
+
+    batch_size, num_channels, height, width = coeff_tensor.shape
+
+    # Determine which channels to visualize
+    if channel_idx is not None:
+        channels_to_show = [channel_idx]
+        sample_coeffs = sample_coeffs[channel_idx:channel_idx+1]
+    else:
+        channels_to_show = list(range(num_channels))
+
+    # Create subplot layout
+    n_channels = len(channels_to_show)
+    cols = min(4, n_channels)  # Max 4 columns
+    rows = (n_channels + cols - 1) // cols  # Ceiling division
+
+    fig, axes = plt.subplots(rows, cols, figsize=figsize)
+
+    # Normalize axes to a flat array so indexing below works for any layout
+    axes = np.atleast_1d(axes).flatten()
+
+    # Plot each channel
+    for i, ch_idx in enumerate(channels_to_show):
+        if i >= len(axes):
+            break
+
+        ax = axes[i]
+
+        # Get coefficient data for this channel
+        coeff_data = sample_coeffs[i].detach().cpu().numpy()
+
+        # Create visualization
+        im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
+
+        # Add colorbar
+        plt.colorbar(im, ax=ax, shrink=0.8)
+
+        # Set title
+        ax.set_title(f'{title_prefix} {band.upper()}{level+1} Ch{ch_idx}\n'
+                     f'Range: [{coeff_data.min():.3f}, {coeff_data.max():.3f}]')
+
+        # Add statistics text
+        stats_text = f'Mean: {coeff_data.mean():.3f}\n' \
+                     f'Std: {coeff_data.std():.3f}\n' \
+                     f'Non-zero: {(np.abs(coeff_data) > 0.01).mean():.1%}'
+
+        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
+                verticalalignment='top', bbox=dict(boxstyle='round',
+                facecolor='white', alpha=0.8), fontsize=8)
+
+    # Hide unused subplots
+    for i in range(n_channels, len(axes)):
+        axes[i].axis('off')
+
+    # Add main title
+    fig.suptitle(f'{title_prefix} Wavelet Coefficients - {band.upper()} Level {level+1}\n'
+                 f'Sample {batch_idx}, Shape: {coeff_tensor.shape}',
+                 fontsize=14, fontweight='bold')
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=150, bbox_inches='tight')
+
+    return fig
+
+
 def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
     """
     Visualize QWT decomposition of input, prediction, and target.
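For context, a hedged sketch of how such a coefficient dict could be assembled with `pywt` (which this module already imports). The `lh`/`hl`/`hh` labels for pywt's horizontal/vertical/diagonal details and the finest-level-first ordering are assumptions; adapt them to whatever transform the loss actually uses:

import numpy as np
import pywt
import torch

x = np.random.randn(2, 4, 32, 32)  # [batch, channel, h, w]
# wavedec2 transforms the last two axes by default, so the batch/channel
# dims pass through; returns [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]
wp = pywt.wavedec2(x, wavelet="haar", level=2, axes=(-2, -1))

coeffs = {"lh": [], "hl": [], "hh": []}
for cH, cV, cD in reversed(wp[1:]):  # reverse so the finest level is index 0
    coeffs["lh"].append(torch.from_numpy(cH))  # horizontal detail (assumed LH)
    coeffs["hl"].append(torch.from_numpy(cV))  # vertical detail (assumed HL)
    coeffs["hh"].append(torch.from_numpy(cD))  # diagonal detail

explore_wavelets(coeffs, coeffs_name="Haar / level 2")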
diff --git a/library/utils.py b/library/utils.py
index b2cd1a01..a5ac531f 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -513,7 +513,8 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
 # Debugging tool for saving latent as image
 def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str):
     with torch.no_grad():
-        image = vae.decode(latent_to.to(vae.dtype)).float()
+        (image,) = vae.decode(latent_to.to(vae.dtype), return_dict=False)
+        image = image.float()
 
         # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
         image = (image / 2 + 0.5).clamp(0, 1)
diff --git a/sd3_train_network.py b/sd3_train_network.py
index cdb7aa4e..4c853de1 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -25,6 +25,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
         super().__init__()
         self.sample_prompts_te_outputs = None
+        self.is_flow_matching = True
 
     def assert_extra_args(
         self,
diff --git a/train_network.py b/train_network.py
index bc860db7..32d6640d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -57,6 +57,7 @@ class NetworkTrainer:
     def __init__(self):
         self.vae_scale_factor = 0.18215
         self.is_sdxl = False
+        self.is_flow_matching = False
 
     # TODO: unify this with the other scripts
     def generate_step_logs(
@@ -172,9 +173,9 @@ class NetworkTrainer:
         train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
         val_dataset_group: Optional[train_util.DatasetGroup],
     ):
-        train_dataset_group.verify_bucket_reso_steps(64)
+        train_dataset_group.verify_bucket_reso_steps(32)
         if val_dataset_group is not None:
-            val_dataset_group.verify_bucket_reso_steps(64)
+            val_dataset_group.verify_bucket_reso_steps(32)
 
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -323,6 +324,7 @@ class NetworkTrainer:
                 target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
 
         sigmas = timesteps / noise_scheduler.config.num_train_timesteps
+        sigmas = sigmas.view(-1, 1, 1, 1)
 
         return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
 
@@ -472,9 +474,22 @@
         if args.wavelet_loss:
             def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
                 if denoise_latents:
-                    # denoise latents to use for wavelet loss
-                    wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
-                    wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
+                    if self.is_flow_matching:
+                        # Denoise latents for the wavelet loss (flow-matching formulation)
+                        wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
+                        wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
+
+                    else:
+                        # Get cumulative alpha products (alpha_bar_t) from the scheduler
+                        alphas_cumprod = noise_scheduler.alphas_cumprod.to(noisy_latents.device)
+                        alpha_t = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
+                        sqrt_alpha_t = torch.sqrt(alpha_t)
+                        sqrt_one_minus_alpha_t = torch.sqrt(1.0 - alpha_t)
+
+                        # Predict x0 (clean latents) from the noise prediction
+                        wavelet_predicted = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t
+                        wavelet_target = (noisy_latents - sqrt_one_minus_alpha_t * noise) / sqrt_alpha_t
+
                     return wavelet_predicted, wavelet_target
                 else:
                     return noise_pred, target
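As a sanity check on the epsilon-prediction branch above: for the DDPM forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, the recovery is exact when the model's noise prediction equals the true noise. A minimal self-contained sketch with a stand-in schedule (the linspace values here are placeholders, not a real scheduler):

import torch

torch.manual_seed(0)
x0 = torch.randn(2, 4, 8, 8)        # clean latents
eps = torch.randn_like(x0)          # true noise
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)  # stand-in schedule
timesteps = torch.tensor([100, 800])

alpha_bar = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
noisy = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps

# Same formula as the non-flow-matching branch, with noise_pred = eps
recovered = (noisy - (1 - alpha_bar).sqrt() * eps) / alpha_bar.sqrt()
print(torch.allclose(recovered, x0, atol=1e-5))  # True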