Handle args.v_parameterization properly for MinSNR and changed prediction target

This commit is contained in:
liesen
2024-08-24 01:38:17 +03:00
committed by GitHub
parent 25f961bc77
commit 1e8108fec9

View File

@@ -590,7 +590,11 @@ def train(args):
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
if (
args.min_snr_gamma
@@ -606,7 +610,7 @@ def train(args):
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: