Fix wavelet loss on non-flow-matching models (SD1.5, SDXL). Fix wavelet correlation.

This commit is contained in:
rockerBOO
2025-07-14 21:20:49 -04:00
parent 8b0a467bc0
commit 8cc81e45f7
5 changed files with 337 additions and 50 deletions

View File

@@ -57,6 +57,7 @@ class NetworkTrainer:
# Set the trainer's default configuration flags.
# Initializes the VAE scale factor and marks the trainer as neither SDXL nor
# flow-matching.
def __init__(self):
# NOTE(review): 0.18215 looks like the Stable Diffusion 1.x VAE latent scaling
# constant — confirm against the model loader.
self.vae_scale_factor = 0.18215
self.is_sdxl = False
# Checked by the wavelet-loss denoising helper to choose between the
# sigma-based formula and the alphas_cumprod-based formula.
self.is_flow_matching = False
# TODO: share this with the other scripts
def generate_step_logs(
@@ -172,9 +173,9 @@ class NetworkTrainer:
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
train_dataset_group.verify_bucket_reso_steps(64)
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -323,6 +324,7 @@ class NetworkTrainer:
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
sigmas = sigmas.view(-1, 1, 1, 1)
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
@@ -472,9 +474,22 @@ class NetworkTrainer:
if args.wavelet_loss:
# Select the tensor pair that the wavelet loss is computed on.
#
# When `denoise_latents` is true, returns (wavelet_predicted, wavelet_target):
# two clean-latent estimates reconstructed from `noisy_latents`, one using the
# model output `noise_pred` and one using the sampled `noise`.
# When it is false, returns `noise_pred` and `target` unchanged.
#
# `self`, `timesteps`, `noise_scheduler` and `target` are not parameters; they
# are read from the enclosing scope.
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
if denoise_latents:
# NOTE(review): the next three lines are identical to the flow-matching branch
# below and look like the pre-change lines kept by the diff view (the +/-
# markers are missing here) — confirm before treating them as live code.
# denoise latents to use for wavelet loss
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
if self.is_flow_matching:
# Sigma-based inversion of the noising step.
# NOTE(review): divides by (1.0 - sigmas), which is zero where sigmas == 1 —
# confirm the timestep sampling never produces that value.
# denoise latents to use for wavelet loss
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
else:
# Get alpha values from scheduler
alphas_cumprod = noise_scheduler.alphas_cumprod.to(noisy_latents.device)
# Look up one cumulative alpha per entry of `timesteps` and reshape to
# (-1, 1, 1, 1); assumes `timesteps` is an integer index tensor and the latents
# are 4-D so the result broadcasts — TODO confirm.
alpha_t = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
sqrt_alpha_t = torch.sqrt(alpha_t)
sqrt_one_minus_alpha_t = torch.sqrt(1.0 - alpha_t)
# Predict x0 (clean latents) from noise prediction
# NOTE(review): this inversion treats `noise_pred` as a noise (epsilon)
# prediction; presumably a v-prediction model would need a different formula —
# confirm. It also divides by sqrt_alpha_t, which is zero wherever the
# scheduler's cumulative alpha is zero.
wavelet_predicted = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t
wavelet_target = (noisy_latents - sqrt_one_minus_alpha_t * noise) / sqrt_alpha_t
return wavelet_predicted, wavelet_target
else:
# No denoising requested: the loss is taken on the raw prediction and target.
return noise_pred, target