mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix training for V-pred and ztSNR
1) Updates the debiased estimation loss function to support V-pred (v-parameterization). 2) Prevents the now-deprecated scaling of the loss when ztSNR (zero terminal SNR) is enabled.
This commit is contained in:
@@ -383,10 +383,10 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
||||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||||
weight = 1 / torch.sqrt(snr_t)
|
if v_prediction:
|
||||||
|
weight = 1 / (snr_t + 1)
|
||||||
|
else:
|
||||||
|
weight = 1 / torch.sqrt(snr_t)
|
||||||
loss = weight * loss
|
loss = weight * loss
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|||||||
@@ -3732,6 +3732,11 @@ def verify_training_args(args: argparse.Namespace):
|
|||||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
|
||||||
|
raise ValueError(
|
||||||
|
"zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません"
|
||||||
|
)
|
||||||
|
|
||||||
if args.v_pred_like_loss and args.v_parameterization:
|
if args.v_pred_like_loss and args.v_parameterization:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
|
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
|
||||||
|
|||||||
@@ -725,12 +725,12 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -474,12 +474,12 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -434,12 +434,12 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -370,10 +370,10 @@ def train(args):
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -993,12 +993,12 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -598,12 +598,12 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -483,10 +483,10 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user