scale v-pred loss like noise pred

This commit is contained in:
Kohya S
2023-06-03 10:52:22 +09:00
parent 0f0158ddaa
commit ec2efe52e4
7 changed files with 79 additions and 17 deletions

View File

@@ -5,20 +5,37 @@ import re
from typing import List, Optional, Union
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
noise_scheduler.all_snr = all_snr.to(device)
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
loss = loss * snr_weight
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
loss = loss * scale
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
parser.add_argument(
"--scale_v_pred_loss_like_noise_pred",
action="store_true",
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",

View File

@@ -2311,6 +2311,11 @@ def verify_training_args(args: argparse.Namespace):
if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -3638,4 +3643,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]