add adding v-pred like loss for noise pred

This commit is contained in:
Kohya S
2023-07-31 08:23:28 +09:00
parent 7e474d21ca
commit 0636399c8c
5 changed files with 41 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
noise_scheduler.all_snr = all_snr.to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
@@ -55,6 +56,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas = alphas
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
@@ -64,11 +66,24 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss
def get_snr_scale(timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
# # show debug info
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
return scale
loss = loss * scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
scale = get_snr_scale(timesteps, noise_scheduler)
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
@@ -87,6 +102,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
action="store_true",
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
)
parser.add_argument(
"--v_pred_like_loss",
type=float,
default=None,
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",

View File

@@ -2868,6 +2868,11 @@ def verify_training_args(args: argparse.Namespace):
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
if args.v_pred_like_loss and args.v_parameterization:
raise ValueError(
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
)
if args.zero_terminal_snr and not args.v_parameterization:
print(