mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add soft_min_snr_gamma
This commit is contained in:
@@ -27,6 +27,7 @@ from library.config_util import (
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
apply_soft_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
@@ -353,6 +354,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.soft_min_snr_gamma:
|
||||
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
|
||||
@@ -68,6 +68,25 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
||||
return loss
|
||||
|
||||
|
||||
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    """Scale a loss tensor by the soft min-SNR-gamma weight of each timestep.

    For every sampled timestep the weight is ``1 / (snr ** 2 + 1 / gamma)``,
    where ``snr`` is looked up in ``noise_scheduler.all_snr``.

    Args:
        loss: Loss tensor. It is multiplied elementwise by the weight, so it
            must broadcast against a tensor holding one entry per timestep
            (presumably a per-sample loss of shape ``(batch,)`` -- confirm
            against the caller).
        timesteps: Iterable of indices into ``noise_scheduler.all_snr``.
        noise_scheduler: Object exposing an indexable ``all_snr`` whose
            elements are tensors.
        gamma: Soft clamp value; converted with ``float()``. Must be non-zero.
        v_prediction: Accepted so the signature matches ``apply_snr_weight``;
            it currently has no effect on the computed weight.

    Returns:
        The weighted loss tensor.
    """
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    soft_min_snr_gamma_weight = 1 / (torch.pow(snr, 2) + (1 / float(gamma)))

    # A per-call debug dump to "snr.txt" used to live here. It called
    # snr.item(), which raises for any batch with more than one timestep, and
    # it appended to a file in the working directory on every training step,
    # so it has been removed.

    # Keep the weight on the same device as the loss before multiplying.
    return loss * soft_min_snr_gamma_weight.to(loss.device)
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
loss = loss * scale
|
||||
@@ -106,6 +125,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--soft_min_snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_v_pred_loss_like_noise_pred",
|
||||
action="store_true",
|
||||
|
||||
@@ -4809,7 +4809,7 @@ def sample_images_common(
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
apply_soft_snr_weight,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
@@ -457,6 +458,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.soft_min_snr_gamma:
|
||||
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from library.config_util import (
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
apply_soft_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
@@ -341,6 +342,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.soft_min_snr_gamma:
|
||||
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
|
||||
@@ -35,6 +35,7 @@ import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
apply_soft_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
@@ -522,6 +523,7 @@ class NetworkTrainer:
|
||||
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
||||
"ss_prior_loss_weight": args.prior_loss_weight,
|
||||
"ss_min_snr_gamma": args.min_snr_gamma,
|
||||
"ss_soft_min_snr_gamma": args.soft_min_snr_gamma,
|
||||
"ss_scale_weight_norms": args.scale_weight_norms,
|
||||
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
||||
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
||||
|
||||
@@ -27,6 +27,7 @@ from library.config_util import (
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
apply_soft_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
@@ -590,6 +591,8 @@ class TextualInversionTrainer:
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.soft_min_snr_gamma:
|
||||
loss = apply_soft_snr_weight(loss, timesteps, noise_scheduler, args.soft_min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
|
||||
Reference in New Issue
Block a user