def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    """Scale a per-sample loss by a soft min-SNR style weight.

    For each sampled timestep the signal-to-noise ratio is looked up in
    ``noise_scheduler.all_snr`` and the loss is multiplied by
    ``1 / (snr ** 2 + 1 / gamma)``.

    Args:
        loss: Per-sample loss tensor; it is multiplied elementwise by the
            weight, so it must broadcast against one weight per timestep.
        timesteps: Iterable of timestep indices into ``noise_scheduler.all_snr``.
        noise_scheduler: Object exposing an indexable ``all_snr`` tensor.
        gamma: Non-zero gamma value; converted with ``float()``.
        v_prediction: Accepted for signature parity with ``apply_snr_weight``;
            currently not used by the weighting.

    Returns:
        The weighted loss tensor.
    """
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    # NOTE: the previous revision appended ``snr.item()`` to "snr.txt" on every
    # call. That was debug output, and ``.item()`` raises for batches with more
    # than one sample, so it has been removed.
    soft_min_snr_gamma_weight = 1 / (torch.pow(snr, 2) + (1 / float(gamma)))
    # Keep the weight on the same device as the loss so the multiply cannot
    # fail on a device mismatch.
    soft_min_snr_gamma_weight = soft_min_snr_gamma_weight.float().to(loss.device)
    loss = loss * soft_min_snr_gamma_weight
    return loss
/ 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", + ) parser.add_argument( "--scale_v_pred_loss_like_noise_pred", action="store_true", diff --git a/library/train_util.py b/library/train_util.py index 4ac6728b..6f62176d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4809,7 +4809,7 @@ def sample_images_common( except ImportError: # 事前に一度確認するのでここはエラー出ないはず raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps) except: # wandb 無効時 pass diff --git a/train_controlnet.py b/train_controlnet.py index 7b0b2bbf..aa57cbe9 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -32,6 +32,7 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, pyramid_noise_like, apply_noise_offset, ) @@ -457,6 +458,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 888cad25..cc8dfba3 100644 --- a/train_db.py +++ b/train_db.py @@ -28,6 +28,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, pyramid_noise_like, @@ -341,6 +342,8 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, 
noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/train_network.py b/train_network.py index 8b6c395c..043df616 100644 --- a/train_network.py +++ b/train_network.py @@ -35,6 +35,7 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, @@ -522,6 +523,7 @@ class NetworkTrainer: "ss_face_crop_aug_range": args.face_crop_aug_range, "ss_prior_loss_weight": args.prior_loss_weight, "ss_min_snr_gamma": args.min_snr_gamma, + "ss_soft_min_snr_gamma": args.soft_min_snr_gamma, "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, "ss_debiased_estimation": bool(args.debiased_estimation_loss), diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 441c1e00..2f25469d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -27,6 +27,7 @@ from library.config_util import ( import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + apply_soft_snr_weight, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, @@ -590,6 +591,8 @@ class TextualInversionTrainer: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.soft_min_snr_gamma: + loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, 
timesteps, noise_scheduler) if args.v_pred_like_loss: