mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Debias Estimation loss (#889)
* update for bnb 0.41.1 * fixed generate_controlnet_subsets_config for training * Revert "update for bnb 0.41.1" This reverts commit70bd3612d8. * add debiased_estimation_loss * add train_network * Revert "add train_network" This reverts commit6539363c5c. * Update train_network.py
This commit is contained in:
@@ -34,6 +34,7 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import library.original_unet as original_unet
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
@@ -471,6 +472,8 @@ def train(args):
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
||||
Reference in New Issue
Block a user