add adaptive noise scale

This commit is contained in:
Kohya S
2023-05-07 18:09:08 +09:00
parent e54b6311ef
commit 09c719c926
7 changed files with 45 additions and 20 deletions

View File

@@ -25,7 +25,7 @@ from library.config_util import (
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
# TODO 他のスクリプトと共通化する
@@ -585,11 +585,11 @@ def train(args):
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)