This commit is contained in:
Dave Lage
2025-09-30 09:58:53 +05:30
committed by GitHub
6 changed files with 27 additions and 0 deletions

View File

@@ -34,6 +34,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
@@ -383,6 +384,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:

View File

@@ -74,6 +74,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
return loss
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
loss = loss * soft_min_snr_gamma_weight
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
@@ -117,6 +124,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
parser.add_argument(
"--soft_min_snr_gamma",
type=float,
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 1 is recommended. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では1が推奨",
)
parser.add_argument(
"--scale_v_pred_loss_like_noise_pred",
action="store_true",

View File

@@ -33,6 +33,7 @@ import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
@@ -490,6 +491,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

View File

@@ -29,6 +29,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
@@ -370,6 +371,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:

View File

@@ -31,6 +31,7 @@ import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
@@ -605,6 +606,7 @@ class NetworkTrainer:
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_soft_min_snr_gamma": args.soft_min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),

View File

@@ -27,6 +27,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
@@ -598,6 +599,8 @@ class TextualInversionTrainer:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: