mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add zero_terminal_snr option
This commit is contained in:
@@ -18,6 +18,42 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
# fix beta: zero terminal SNR
|
||||
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
||||
|
||||
def enforce_zero_terminal_snr(betas):
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1 - betas
|
||||
alphas_bar = alphas.cumprod(0)
|
||||
alphas_bar_sqrt = alphas_bar.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
# Shift so last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
# Scale so first timestep is back to old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
return betas
|
||||
|
||||
betas = noise_scheduler.betas
|
||||
betas = enforce_zero_terminal_snr(betas)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
# print("original:", noise_scheduler.betas)
|
||||
# print("fixed:", betas)
|
||||
|
||||
noise_scheduler.betas = betas
|
||||
noise_scheduler.alphas = alphas
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
|
||||
@@ -343,9 +343,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
assert (
|
||||
not args.v2 and not args.v_parameterization
|
||||
), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
|
||||
if args.clip_skip is not None:
|
||||
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
|
||||
@@ -2750,6 +2750,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero_terminal_snr",
|
||||
action="store_true",
|
||||
help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_timestep",
|
||||
type=int,
|
||||
@@ -2825,7 +2830,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
if args.v_parameterization and not args.v2:
|
||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
@@ -2856,6 +2861,12 @@ def verify_training_args(args: argparse.Namespace):
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
if args.zero_terminal_snr and not args.v_parameterization:
|
||||
print(
|
||||
f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
|
||||
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
|
||||
)
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
|
||||
Reference in New Issue
Block a user