Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-10-25 19:03:27 +09:00
13 changed files with 16 additions and 18 deletions

View File

@@ -618,7 +618,7 @@ class TextualInversionTrainer:
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし