Add soft_min_snr_gamma

This commit is contained in:
rockerBOO
2024-01-24 00:15:07 -05:00
parent d5ab97b69b
commit d8155bfbe8
7 changed files with 40 additions and 1 deletions

View File

@@ -27,6 +27,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
@@ -353,6 +354,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:

View File

@@ -68,6 +68,25 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
return loss
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    """Scale per-sample loss with a "soft" min-SNR weight.

    Unlike apply_snr_weight, which hard-clamps SNR at gamma, this computes a
    smooth weight  1 / (snr^2 + 1/gamma)  so the transition around gamma is
    gradual rather than a hard cutoff.

    Args:
        loss: per-sample loss tensor (one entry per timestep in the batch).
        timesteps: iterable of scheduler timestep indices, one per sample.
        noise_scheduler: scheduler with a precomputed ``all_snr`` tensor
            (see prepare_scheduler_for_custom_training).
        gamma: soft min-SNR gamma; smaller values weight high-loss (low-SNR)
            timesteps down more strongly.
        v_prediction: accepted for signature parity with apply_snr_weight.
            NOTE(review): currently unused — confirm whether the v-prediction
            branch (dividing by snr+1) is intentionally omitted.

    Returns:
        The weighted loss tensor.
    """
    # Per-sample SNR looked up from the scheduler's precomputed table.
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    # Soft min-SNR weight: smooth analogue of min(snr, gamma)-style clamping.
    soft_min_snr_gamma_weight = 1.0 / (torch.pow(snr, 2) + (1.0 / float(gamma)))
    # Match the dtype/device of the loss, as apply_snr_weight does.
    loss = loss * soft_min_snr_gamma_weight.float().to(loss.device)
    return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
@@ -106,6 +125,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
parser.add_argument(
"--soft_min_snr_gamma",
type=float,
default=None,
help="gamma for soft min-SNR weighting: smoothly reduces the weight of high-loss (low-SNR) timesteps instead of hard-clamping like --min_snr_gamma. Lower numbers have stronger effect. / soft min SNR用のgamma値、--min_snr_gammaのように固定値でクリップせず滑らかに低SNRタイムステップの重みを減らす。低いほど効果が強い",
)
parser.add_argument(
"--scale_v_pred_loss_like_noise_pred",
action="store_true",

View File

@@ -4809,7 +4809,7 @@ def sample_images_common(
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
except: # wandb 無効時
pass

View File

@@ -32,6 +32,7 @@ import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
@@ -457,6 +458,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

View File

@@ -28,6 +28,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
@@ -341,6 +342,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:

View File

@@ -35,6 +35,7 @@ import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
@@ -522,6 +523,7 @@ class NetworkTrainer:
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_soft_min_snr_gamma": args.soft_min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),

View File

@@ -27,6 +27,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
apply_soft_snr_weight,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
@@ -590,6 +591,8 @@ class TextualInversionTrainer:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.soft_min_snr_gamma:
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: