Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-10-25 19:03:27 +09:00
13 changed files with 16 additions and 18 deletions

View File

@@ -733,7 +733,7 @@ def train(args):
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else: