mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Debias Estimation loss (#889)
* update for bnb 0.41.1 * fixed generate_controlnet_subsets_config for training * Revert "update for bnb 0.41.1" This reverts commit70bd3612d8. * add debiased_estimation_loss * add train_network * Revert "add train_network" This reverts commit6539363c5c. * Update train_network.py
This commit is contained in:
@@ -32,6 +32,7 @@ from library.custom_train_functions import (
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@@ -339,7 +340,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss,:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -348,6 +349,8 @@ def train(args):
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user