Add noise_offset

This commit is contained in:
Kohya S
2023-02-14 21:15:48 +09:00
parent e0f007f2a9
commit 43c0a69843
5 changed files with 27 additions and 13 deletions

View File

@@ -255,6 +255,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)

View File

@@ -300,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.shuffle_keep_tokens is None:
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
@@ -309,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
@@ -1102,10 +1102,10 @@ def addnet_hash_safetensors(b):
def get_git_revision_hash() -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
except:
return "(unknown)"
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
except:
return "(unknown)"
# flash attention forwards and backwards
@@ -1421,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--noise_offset", type=float, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)")
if support_dreambooth:
# DreamBooth training
@@ -1653,10 +1655,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
return encoder_hidden_states

View File

@@ -233,10 +233,13 @@ def train(args):
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):

View File

@@ -278,9 +278,9 @@ def train(args):
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
@@ -419,6 +419,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)

View File

@@ -320,6 +320,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)