mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Fix wavelet loss on non-flow matching models (sd1.5, SDXL). Fix wavelet coorelation.
This commit is contained in:
@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.is_flow_matching = True
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
|
||||
@@ -7,11 +7,14 @@ import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union, Protocol
|
||||
from .utils import setup_logging
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
try:
|
||||
import pywt
|
||||
except:
|
||||
@@ -1064,7 +1067,7 @@ class WaveletLoss(nn.Module):
|
||||
energy_ratio: float = 0.0,
|
||||
energy_scale_factor: float = 0.01,
|
||||
normalize_bands: bool = True,
|
||||
max_timestep: float = 1.0,
|
||||
max_timestep: float = 1000,
|
||||
timestep_intensity: float = 0.5,
|
||||
):
|
||||
"""
|
||||
@@ -1156,13 +1159,10 @@ class WaveletLoss(nn.Module):
|
||||
band_weights = self.band_weights
|
||||
band_level_weights = self.band_level_weights
|
||||
|
||||
# Apply timestep-based weighting if provided
|
||||
# if timestep is not None:
|
||||
# # Let users control intensity of timestep weighting (0.5 = moderate effect)
|
||||
# intensity = getattr(self, "timestep_intensity", 0.5)
|
||||
# current_band_weights, current_band_level_weights = self.noise_aware_weighting(
|
||||
# timestep, self.max_timestep, intensity=intensity
|
||||
# )
|
||||
base_weight = torch.ones((batch_size), device=device)
|
||||
if timestep is not None:
|
||||
base_weight *= self.smooth_timestep_weight(timestep)
|
||||
metrics['wavelet_loss/avg_timestep_adjusted_weight'] = base_weight.detach().mean().item()
|
||||
|
||||
# If negative it's from the end of the levels else it's the level.
|
||||
ll_threshold = None
|
||||
@@ -1180,6 +1180,8 @@ class WaveletLoss(nn.Module):
|
||||
continue
|
||||
|
||||
weight_key = f"{band}{i+1}"
|
||||
pred = pred_coeffs[band][i]
|
||||
target = target_coeffs[band][i]
|
||||
|
||||
if band in pred_coeffs and band in target_coeffs:
|
||||
if self.normalize_bands:
|
||||
@@ -1187,9 +1189,34 @@ class WaveletLoss(nn.Module):
|
||||
pred_coeffs[band][i] = (pred_coeffs[band][i] - pred_coeffs[band][i].mean()) / (pred_coeffs[band][i].std() + 1e-8)
|
||||
target_coeffs[band][i] = (target_coeffs[band][i] - target_coeffs[band][i].mean()) / (target_coeffs[band][i].std() + 1e-8)
|
||||
|
||||
weight = band_level_weights.get(weight_key, band_weights[band])
|
||||
band_loss = weight * self.loss_fn(pred_coeffs[band][i], target_coeffs[band][i])
|
||||
pattern_level_losses += band_loss.mean(dim=0) # mean stack dim
|
||||
# 1. Magnitude loss
|
||||
band_loss = self.loss_fn(pred, target)
|
||||
|
||||
# 2. Local structure loss
|
||||
pred_grad_x = torch.diff(pred, dim=-1)
|
||||
pred_grad_y = torch.diff(pred, dim=-2)
|
||||
target_grad_x = torch.diff(target, dim=-1)
|
||||
target_grad_y = torch.diff(target, dim=-2)
|
||||
|
||||
gradient_loss = F.mse_loss(pred_grad_x, target_grad_x) + \
|
||||
F.mse_loss(pred_grad_y, target_grad_y)
|
||||
|
||||
# 3. Global correlation per channel
|
||||
B, C = pred.shape[:2]
|
||||
pred_flat = pred.view(B, C, -1)
|
||||
target_flat = target.view(B, C, -1)
|
||||
|
||||
cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=2)
|
||||
correlation_loss = (1 - cos_sim).mean()
|
||||
|
||||
weight = base_weight * band_level_weights.get(weight_key, band_weights[band])
|
||||
pattern_level_losses += weight.view(-1, 1, 1, 1) * (band_loss +
|
||||
0.05 * gradient_loss +
|
||||
0.1 * correlation_loss) # mean stack dim
|
||||
|
||||
metrics[f"{band}{i}_band_loss"] = band_loss.detach().mean().item()
|
||||
metrics[f"{band}{i}_gradient_loss"] = gradient_loss.detach().mean().item()
|
||||
metrics[f"{band}{i}_correlation_loss"] = correlation_loss.detach().mean().item()
|
||||
|
||||
# Collect high frequency bands for visualization
|
||||
combined_hf_pred.append(pred_coeffs[band][i])
|
||||
@@ -1405,37 +1432,33 @@ class WaveletLoss(nn.Module):
|
||||
def calculate_correlation_metrics(self, pred_coeffs: dict[str, list[Tensor]], target_coeffs: dict[str, list[Tensor]]) -> dict:
|
||||
"""Calculate correlation metrics between prediction and target wavelet coefficients"""
|
||||
metrics = {}
|
||||
avg_correlations = []
|
||||
|
||||
|
||||
for band in ["lh", "hl", "hh"]:
|
||||
for i in range(1, self.level + 1):
|
||||
# Get coefficients
|
||||
pred = pred_coeffs[band][i - 1]
|
||||
target = target_coeffs[band][i - 1]
|
||||
|
||||
# Flatten for batch-wise correlation
|
||||
batch_size = pred.shape[0]
|
||||
pred_flat = pred.view(batch_size, -1)
|
||||
target_flat = target.view(batch_size, -1)
|
||||
|
||||
# Center data
|
||||
pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
|
||||
target_centered = target_flat - target_flat.mean(dim=1, keepdim=True)
|
||||
|
||||
# Calculate correlation
|
||||
numerator = torch.sum(pred_centered * target_centered, dim=1)
|
||||
denominator = torch.sqrt(torch.sum(pred_centered**2, dim=1) * torch.sum(target_centered**2, dim=1) + 1e-8)
|
||||
correlation = numerator / denominator
|
||||
|
||||
# Average across batch
|
||||
avg_correlation = correlation.mean().item()
|
||||
metrics[f"{band}{i}_correlation"] = avg_correlation
|
||||
avg_correlations.append(avg_correlation)
|
||||
|
||||
# Calculate average correlation across all bands
|
||||
if avg_correlations:
|
||||
metrics["avg_correlation"] = sum(avg_correlations) / len(avg_correlations)
|
||||
|
||||
band_correlations = []
|
||||
for i in range(self.level):
|
||||
pred = pred_coeffs[band][i] # [B, C, H, W]
|
||||
target = target_coeffs[band][i]
|
||||
|
||||
# Flatten spatial dims but keep batch/channel separate
|
||||
pred_flat = pred.flatten(start_dim=2) # [B, C, H*W]
|
||||
target_flat = target.flatten(start_dim=2)
|
||||
|
||||
# Calculate correlation across spatial dimension
|
||||
pred_centered = pred_flat - pred_flat.mean(dim=2, keepdim=True)
|
||||
target_centered = target_flat - target_flat.mean(dim=2, keepdim=True)
|
||||
|
||||
numerator = torch.sum(pred_centered * target_centered, dim=2)
|
||||
denom = torch.sqrt(torch.sum(pred_centered**2, dim=2) *
|
||||
torch.sum(target_centered**2, dim=2) + 1e-8)
|
||||
|
||||
correlation = numerator / denom # [B, C]
|
||||
avg_corr = correlation.mean().item()
|
||||
|
||||
metrics[f"{band}{i+1}_spatial_correlation"] = avg_corr
|
||||
band_correlations.append(avg_corr)
|
||||
|
||||
metrics[f"{band}_avg_correlation"] = np.mean(band_correlations)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1547,12 +1570,20 @@ class WaveletLoss(nn.Module):
|
||||
|
||||
# Average sparsity across bands
|
||||
if band_sparsities:
|
||||
metrics["avg_l1_sparsity"] = sum(band_sparsities) / len(band_sparsities)
|
||||
if band_non_zero_ratios: # Add this
|
||||
metrics["avg_non_zero_ratio"] = sum(band_non_zero_ratios) / len(band_non_zero_ratios)
|
||||
metrics["avg_sparsity_score"] = 1.0 / (sum(band_sparsities) / len(band_sparsities) + 1e-8)
|
||||
|
||||
|
||||
return metrics
|
||||
|
||||
def smooth_timestep_weight(self, timestep):
|
||||
"""Smooth weight transition instead of hard cutoff"""
|
||||
|
||||
progress = 1.0 - (timestep / self.max_timestep)
|
||||
|
||||
weight = torch.sigmoid((progress - 0.3) * 10)
|
||||
|
||||
return weight
|
||||
|
||||
# TODO: does not work right in terms of weighting in an appropriate range
|
||||
def noise_aware_weighting(self, timestep: Tensor, max_timestep: float, intensity=1.0):
|
||||
"""
|
||||
@@ -1680,6 +1711,244 @@ class WaveletLoss(nn.Module):
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
|
||||
def explore_wavelets(coeffs, coeffs_name="Coefficients"):
|
||||
"""Interactive exploration of wavelet coefficients"""
|
||||
|
||||
bands = list(coeffs.keys())
|
||||
levels = list(range(len(coeffs[bands[0]])))
|
||||
batch_size, n_channels = coeffs[bands[0]][0].shape[:2]
|
||||
|
||||
print(f"\n=== {coeffs_name} Structure ===")
|
||||
print(f"Bands: {bands}")
|
||||
print(f"Levels: {levels}")
|
||||
print(f"Batch size: {batch_size}")
|
||||
print(f"Channels: {n_channels}")
|
||||
|
||||
for band in bands:
|
||||
for level in levels:
|
||||
shape = coeffs[band][level].shape
|
||||
sparsity = (torch.abs(coeffs[band][level]) < 0.01).float().mean().item()
|
||||
magnitude = torch.abs(coeffs[band][level]).mean().item()
|
||||
|
||||
print(f"{band.upper()}{level+1}: shape={shape}, "
|
||||
f"sparsity={sparsity:.1%}, avg_magnitude={magnitude:.4f}")
|
||||
|
||||
# During training, visualize specific coefficients
|
||||
def visualize_training_wavelets(pred_coeffs, target_coeffs, step):
|
||||
"""Call this during training to save wavelet visualizations"""
|
||||
|
||||
# 1. Visualize predicted coefficients for LH band, level 0
|
||||
fig1 = visualize_wavelet_coefficients(
|
||||
pred_coeffs, band='lh', level=0, batch_idx=0,
|
||||
title_prefix="Predicted",
|
||||
save_path=f"wavelets/pred_lh1_step_{step}.png"
|
||||
)
|
||||
plt.close(fig1)
|
||||
|
||||
# 2. Compare predicted vs target
|
||||
fig2 = compare_wavelet_coefficients(
|
||||
pred_coeffs, target_coeffs, band='hl', level=1,
|
||||
batch_idx=0, channel_idx=0,
|
||||
save_path=f"wavelets/comparison_hl2_step_{step}.png"
|
||||
)
|
||||
plt.close(fig2)
|
||||
|
||||
# 3. Overview of all bands
|
||||
fig3 = visualize_all_bands_levels(
|
||||
pred_coeffs, title_prefix="Predicted", batch_idx=0, channel_idx=0,
|
||||
save_path=f"wavelets/overview_step_{step}.png"
|
||||
)
|
||||
plt.close(fig3)
|
||||
|
||||
def visualize_all_bands_levels(coeffs, title_prefix="", batch_idx=0,
|
||||
channel_idx=0, save_path=None):
|
||||
"""
|
||||
Show all wavelet bands and levels in one overview plot
|
||||
"""
|
||||
|
||||
bands = ['lh', 'hl', 'hh']
|
||||
n_levels = len(coeffs['lh']) # Assuming all bands have same levels
|
||||
|
||||
fig, axes = plt.subplots(len(bands), n_levels, figsize=(4*n_levels, 3*len(bands)))
|
||||
|
||||
if n_levels == 1:
|
||||
axes = axes.reshape(-1, 1)
|
||||
|
||||
for band_idx, band in enumerate(bands):
|
||||
for level in range(n_levels):
|
||||
ax = axes[band_idx, level]
|
||||
|
||||
# Get coefficient data
|
||||
coeff_data = coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
||||
|
||||
# Plot
|
||||
im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
|
||||
ax.set_title(f'{band.upper()}{level+1}')
|
||||
|
||||
# Add colorbar for better interpretation
|
||||
plt.colorbar(im, ax=ax, shrink=0.6)
|
||||
|
||||
# Add sparsity info
|
||||
sparsity = (np.abs(coeff_data) < 0.01).mean()
|
||||
ax.text(0.02, 0.02, f'Sparse: {sparsity:.1%}',
|
||||
transform=ax.transAxes, bbox=dict(boxstyle='round',
|
||||
facecolor='white', alpha=0.8), fontsize=8)
|
||||
|
||||
fig.suptitle(f'{title_prefix} All Wavelet Bands - Sample {batch_idx}, Channel {channel_idx}',
|
||||
fontsize=14)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def compare_wavelet_coefficients(pred_coeffs, target_coeffs, band, level,
|
||||
batch_idx=0, channel_idx=0, save_path=None):
|
||||
"""
|
||||
Side-by-side comparison of predicted vs target coefficients
|
||||
"""
|
||||
|
||||
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
|
||||
|
||||
# Get data
|
||||
pred_data = pred_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
||||
target_data = target_coeffs[band][level][batch_idx, channel_idx].detach().cpu().numpy()
|
||||
|
||||
# Calculate difference
|
||||
diff_data = pred_data - target_data
|
||||
|
||||
# Determine common color scale
|
||||
vmin = min(pred_data.min(), target_data.min())
|
||||
vmax = max(pred_data.max(), target_data.max())
|
||||
|
||||
# Plot predicted
|
||||
im1 = ax1.imshow(pred_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
||||
ax1.set_title(f'Predicted {band.upper()}{level+1} Ch{channel_idx}')
|
||||
plt.colorbar(im1, ax=ax1, shrink=0.8)
|
||||
|
||||
# Plot target
|
||||
im2 = ax2.imshow(target_data, cmap='RdBu_r', vmin=vmin, vmax=vmax)
|
||||
ax2.set_title(f'Target {band.upper()}{level+1} Ch{channel_idx}')
|
||||
plt.colorbar(im2, ax=ax2, shrink=0.8)
|
||||
|
||||
# Plot difference
|
||||
im3 = ax3.imshow(diff_data, cmap='RdBu_r', vmin=-np.abs(diff_data).max(),
|
||||
vmax=np.abs(diff_data).max())
|
||||
ax3.set_title('Difference (Pred - Target)')
|
||||
plt.colorbar(im3, ax=ax3, shrink=0.8)
|
||||
|
||||
# Add correlation info
|
||||
correlation = np.corrcoef(pred_data.flatten(), target_data.flatten())[0,1]
|
||||
mse = np.mean((pred_data - target_data)**2)
|
||||
|
||||
fig.suptitle(f'Wavelet Comparison - Correlation: {correlation:.3f}, MSE: {mse:.6f}',
|
||||
fontsize=14)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
|
||||
return fig
|
||||
|
||||
def visualize_wavelet_coefficients(coeffs, band, level, batch_idx=0,
|
||||
channel_idx=None, title_prefix="",
|
||||
save_path=None, figsize=(15, 10)):
|
||||
"""
|
||||
Visualize wavelet coefficients for a specific band and level
|
||||
|
||||
Args:
|
||||
coeffs: dict with structure coeffs[band][level] -> [batch, channel, h, w]
|
||||
band: str, one of ['lh', 'hl', 'hh']
|
||||
level: int, wavelet decomposition level (0-indexed)
|
||||
batch_idx: int, which sample in batch to visualize
|
||||
channel_idx: int or None, specific channel to show (None = all channels)
|
||||
title_prefix: str, prefix for plot titles (e.g., "Predicted" or "Target")
|
||||
save_path: str or None, path to save the plot
|
||||
figsize: tuple, figure size
|
||||
|
||||
Returns:
|
||||
fig: matplotlib figure object
|
||||
"""
|
||||
|
||||
# Extract the specific coefficients
|
||||
coeff_tensor = coeffs[band][level] # [batch, channel, h, w]
|
||||
|
||||
# Get single sample
|
||||
sample_coeffs = coeff_tensor[batch_idx] # [channel, h, w]
|
||||
|
||||
batch_size, num_channels, height, width = coeff_tensor.shape
|
||||
|
||||
# Determine which channels to visualize
|
||||
if channel_idx is not None:
|
||||
channels_to_show = [channel_idx]
|
||||
sample_coeffs = sample_coeffs[channel_idx:channel_idx+1]
|
||||
else:
|
||||
channels_to_show = list(range(num_channels))
|
||||
|
||||
# Create subplot layout
|
||||
n_channels = len(channels_to_show)
|
||||
cols = min(4, n_channels) # Max 4 columns
|
||||
rows = (n_channels + cols - 1) // cols # Ceiling division
|
||||
|
||||
fig, axes = plt.subplots(rows, cols, figsize=figsize)
|
||||
|
||||
# Handle single subplot case
|
||||
if n_channels == 1:
|
||||
axes = [axes]
|
||||
elif rows == 1:
|
||||
axes = [axes] if n_channels == 1 else axes
|
||||
else:
|
||||
axes = axes.flatten()
|
||||
|
||||
# Plot each channel
|
||||
for i, ch_idx in enumerate(channels_to_show):
|
||||
if i >= len(axes):
|
||||
break
|
||||
|
||||
ax = axes[i]
|
||||
|
||||
# Get coefficient data for this channel
|
||||
coeff_data = sample_coeffs[i].detach().cpu().numpy()
|
||||
|
||||
# Create visualization
|
||||
im = ax.imshow(coeff_data, cmap='RdBu_r', aspect='auto')
|
||||
|
||||
# Add colorbar
|
||||
plt.colorbar(im, ax=ax, shrink=0.8)
|
||||
|
||||
# Set title
|
||||
ax.set_title(f'{title_prefix} {band.upper()}{level+1} Ch{ch_idx}\n'
|
||||
f'Range: [{coeff_data.min():.3f}, {coeff_data.max():.3f}]')
|
||||
|
||||
# Add statistics text
|
||||
stats_text = f'Mean: {coeff_data.mean():.3f}\n' \
|
||||
f'Std: {coeff_data.std():.3f}\n' \
|
||||
f'Non-zero: {(np.abs(coeff_data) > 0.01).mean():.1%}'
|
||||
|
||||
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round',
|
||||
facecolor='white', alpha=0.8), fontsize=8)
|
||||
|
||||
# Hide unused subplots
|
||||
for i in range(n_channels, len(axes)):
|
||||
axes[i].axis('off')
|
||||
|
||||
# Add main title
|
||||
fig.suptitle(f'{title_prefix} Wavelet Coefficients - {band.upper()} Level {level+1}\n'
|
||||
f'Sample {batch_idx}, Shape: {coeff_tensor.shape}',
|
||||
fontsize=14, fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
|
||||
return fig
|
||||
|
||||
def visualize_qwt_results(qwt_transform, lr_image, pred_latent, target_latent, filename):
|
||||
"""
|
||||
Visualize QWT decomposition of input, prediction, and target.
|
||||
|
||||
@@ -513,7 +513,8 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
# Debugging tool for saving latent as image
|
||||
def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str):
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latent_to.to(vae.dtype)).float()
|
||||
(image,) = vae.decode(latent_to.to(vae.dtype), return_dict=False)
|
||||
image = image.float()
|
||||
# VAE outputs are typically in the range [-1, 1], so rescale to [0, 255]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_flow_matching = True
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
|
||||
@@ -57,6 +57,7 @@ class NetworkTrainer:
|
||||
def __init__(self):
|
||||
self.vae_scale_factor = 0.18215
|
||||
self.is_sdxl = False
|
||||
self.is_flow_matching = False
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(
|
||||
@@ -172,9 +173,9 @@ class NetworkTrainer:
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -323,6 +324,7 @@ class NetworkTrainer:
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
|
||||
|
||||
@@ -472,9 +474,22 @@ class NetworkTrainer:
|
||||
if args.wavelet_loss:
|
||||
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
|
||||
if denoise_latents:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
if self.is_flow_matching:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
|
||||
else:
|
||||
# Get alpha values from scheduler
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod.to(noisy_latents.device)
|
||||
alpha_t = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
sqrt_one_minus_alpha_t = torch.sqrt(1.0 - alpha_t)
|
||||
|
||||
# Predict x0 (clean latents) from noise prediction
|
||||
wavelet_predicted = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t
|
||||
wavelet_target = (noisy_latents - sqrt_one_minus_alpha_t * noise) / sqrt_alpha_t
|
||||
|
||||
return wavelet_predicted, wavelet_target
|
||||
else:
|
||||
return noise_pred, target
|
||||
|
||||
Reference in New Issue
Block a user