diff --git a/fine_tune.py b/fine_tune.py index c79f97d2..996d1d0e 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -34,6 +34,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, @@ -383,6 +384,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index faf44304..2a4a5a3f 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -74,6 +74,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False return loss +def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma))) + loss = loss * soft_min_snr_gamma_weight + return loss + + def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): scale = get_snr_scale(timesteps, noise_scheduler) loss = loss * scale @@ -117,6 +124,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. 
/ 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", ) + parser.add_argument( + "--soft_min_snr_gamma", + type=float, + default=None, + help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 1 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では1が推奨", + ) parser.add_argument( "--scale_v_pred_loss_like_noise_pred", action="store_true", diff --git a/train_controlnet.py b/train_controlnet.py index 6938c4bc..ff855e12 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -33,6 +33,7 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, pyramid_noise_like, apply_noise_offset, ) @@ -490,6 +491,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index e7cf3cde..585ae24d 100644 --- a/train_db.py +++ b/train_db.py @@ -29,6 +29,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, pyramid_noise_like, @@ -370,6 +371,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if
args.debiased_estimation_loss: diff --git a/train_network.py b/train_network.py index 6953bb17..eaf7aca3 100644 --- a/train_network.py +++ b/train_network.py @@ -31,6 +31,7 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, @@ -605,6 +606,7 @@ class NetworkTrainer: "ss_face_crop_aug_range": args.face_crop_aug_range, "ss_prior_loss_weight": args.prior_loss_weight, "ss_min_snr_gamma": args.min_snr_gamma, + "ss_soft_min_snr_gamma": args.soft_min_snr_gamma, "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, "ss_debiased_estimation": bool(args.debiased_estimation_loss), diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 37349da7..3beac841 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -27,6 +27,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, @@ -598,6 +599,8 @@ class TextualInversionTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: