From 518a18aeff9c2f9a830565e0623729fd938590d3 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Thu, 23 Mar 2023 12:34:49 +0000 Subject: [PATCH] (ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation --- fine_tune.py | 2 +- library/custom_train_functions.py | 20 ++++++++++++-------- train_db.py | 3 ++- train_network.py | 4 +--- train_textual_inversion.py | 2 +- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff33eb9c..45f4b9db 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -306,7 +306,7 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 5e880c9a..b080b40c 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,16 +1,20 @@ import torch import argparse +import numpy as np -def apply_snr_weight(loss, latents, noisy_latents, gamma): - sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.to(loss.device) + snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper loss = loss * snr_weight - #print(snr_weight) return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): diff --git a/train_db.py b/train_db.py index ee9beda9..52195b92 100644 --- a/train_db.py +++ b/train_db.py @@ -293,7 +293,8 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 715da8c1..145dd600 100644 --- a/train_network.py +++ b/train_network.py @@ -489,7 +489,6 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) - if accelerator.is_main_process: accelerator.init_trackers("network_train") @@ -529,7 +528,6 @@ def train(args): # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -551,7 +549,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5fe662f6..0694dbb6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -381,7 +381,7 @@ def train(args): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights