(ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to match the authors' implementation

This commit is contained in:
AI-Casanova
2023-03-23 12:34:49 +00:00
parent a3c7d711e4
commit 518a18aeff
5 changed files with 17 additions and 14 deletions
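
The fix replaces the previous commit's empirical SNR estimate, computed from latents and noisy_latents, with the exact per-timestep SNR implied by the scheduler's noise schedule: for the forward process x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, the signal-to-noise ratio is SNR(t) = abar_t / (1 - abar_t), and the Min-SNR paper weights each sample's loss by min(gamma / SNR(t), 1). A minimal standalone sketch of that weighting (the DDPMScheduler configuration is copied from the diff below; gamma = 5.0 is an illustrative value, not part of this commit):

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
gamma = 5.0  # illustrative; set via --min_snr_gamma in training

# SNR(t) = abar_t / (1 - abar_t), identical to (sqrt(abar_t) / sqrt(1 - abar_t)) ** 2 in the new helper.
alphas_cumprod = noise_scheduler.alphas_cumprod
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)

# Min-SNR weight: low-noise timesteps (SNR > gamma) are down-weighted to gamma / SNR(t);
# high-noise timesteps (SNR <= gamma) keep their full weight of 1.
snr_weight = torch.minimum(gamma / all_snr, torch.ones_like(all_snr))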

View File

@@ -306,7 +306,7 @@ def train(args):
 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

 if args.min_snr_gamma:
-    loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
+    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

 accelerator.backward(loss)
 if accelerator.sync_gradients and args.max_grad_norm != 0.0:

View File

@@ -1,16 +1,20 @@
 import torch
 import argparse
+import numpy as np

-def apply_snr_weight(loss, latents, noisy_latents, gamma):
-    sigma = torch.sub(noisy_latents, latents)  # find noise as applied by scheduler
-    zeros = torch.zeros_like(sigma)
-    alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])  # trick to get Mean Square/Second Moment
-    sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])  # trick to get Mean Square/Second Moment
-    snr = torch.div(alpha_mean_sq, sigma_mean_sq)  # Signal to Noise Ratio = ratio of Mean Squares
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+    alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
+    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+    sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
+    alpha = sqrt_alphas_cumprod
+    sigma = sqrt_one_minus_alphas_cumprod
+    all_snr = (alpha / sigma) ** 2
+    all_snr.to(loss.device)
+    snr = torch.stack([all_snr[t] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
-    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()  # from paper
+    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
     loss = loss * snr_weight
-    #print(snr_weight)
     return loss

+
 def add_custom_train_arguments(parser: argparse.ArgumentParser):
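
For reference, a sketch of how the call sites elsewhere in this commit exercise the new signature; the import path is an assumption (in sd-scripts the helper lives in library/custom_train_functions.py), and the loss tensor is a random stand-in for a per-sample loss (reduction="none" followed by .mean([1, 2, 3])):

import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import apply_snr_weight  # assumed module path

noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
b_size = 4
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,)).long()
loss = torch.rand(b_size)  # stand-in for the real per-sample loss

loss = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0)  # gamma = 5.0, illustrative
assert loss.shape == (b_size,)  # weighting preserves the per-sample shape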

View File

@@ -293,7 +293,8 @@ def train(args):
 loss = loss * loss_weights
 if args.min_snr_gamma:
-    loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
+    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 loss = loss.mean()  # already an average, so no need to divide by batch_size

View File

@@ -489,7 +489,6 @@ def train(args):
 noise_scheduler = DDPMScheduler(
     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
 )

-
 if accelerator.is_main_process:
     accelerator.init_trackers("network_train")
@@ -529,7 +528,6 @@ def train(args):
 # Sample a random timestep for each image
 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
 timesteps = timesteps.long()
-
 # Add noise to the latents according to the noise magnitude at each timestep
 # (this is the forward diffusion process)
 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -551,7 +549,7 @@ def train(args):
 loss = loss * loss_weights
 if args.min_snr_gamma:
-    loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
+    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 loss = loss.mean()  # already an average, so no need to divide by batch_size

View File

@@ -381,7 +381,7 @@ def train(args):
 loss = loss.mean([1, 2, 3])
 if args.min_snr_gamma:
-    loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
+    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 loss_weights = batch["loss_weights"]  # per-sample weight
 loss = loss * loss_weights