mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
(ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to match the author's implementation
This commit is contained in:
@@ -306,7 +306,7 @@ def train(args):
|
|||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def apply_snr_weight(loss, latents, noisy_latents, gamma):
|
|
||||||
sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler
|
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||||
zeros = torch.zeros_like(sigma)
|
alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
|
||||||
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment
|
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
|
||||||
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment
|
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
|
||||||
snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares
|
alpha = sqrt_alphas_cumprod
|
||||||
|
sigma = sqrt_one_minus_alphas_cumprod
|
||||||
|
all_snr = (alpha / sigma) ** 2
|
||||||
|
all_snr.to(loss.device)
|
||||||
|
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||||
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
|
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
|
||||||
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
|
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper
|
||||||
loss = loss * snr_weight
|
loss = loss * snr_weight
|
||||||
#print(snr_weight)
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def add_custom_train_arguments(parser: argparse.ArgumentParser):
|
def add_custom_train_arguments(parser: argparse.ArgumentParser):
|
||||||
|
|||||||
@@ -293,7 +293,8 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -489,7 +489,6 @@ def train(args):
|
|||||||
noise_scheduler = DDPMScheduler(
|
noise_scheduler = DDPMScheduler(
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("network_train")
|
accelerator.init_trackers("network_train")
|
||||||
|
|
||||||
@@ -529,7 +528,6 @@ def train(args):
|
|||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
timesteps = timesteps.long()
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
@@ -551,7 +549,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -381,7 +381,7 @@ def train(args):
|
|||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|||||||
Reference in New Issue
Block a user