diff --git a/fine_tune.py b/fine_tune.py index 52921530..3ba63063 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -255,6 +255,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/library/train_util.py b/library/train_util.py index 9125108b..415f9b70 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -300,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_keep_tokens is None: if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) else: if len(tokens) > self.shuffle_keep_tokens: @@ -309,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset): if self.shuffle_caption: random.shuffle(tokens) - + tokens = dropout_tags(tokens) tokens = keep_tokens + tokens @@ -1102,10 +1102,10 @@ def addnet_hash_safetensors(b): def get_git_revision_hash() -> str: - try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() - except: - return "(unknown)" + try: + return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + except: + return "(unknown)" # flash attention forwards and backwards @@ -1421,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--noise_offset", type=float, default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") if support_dreambooth: # DreamBooth training @@ -1653,10 +1655,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # encoder_hidden_states = torch.cat(states_list, dim=1) - + if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) return encoder_hidden_states diff --git a/train_db.py b/train_db.py index c210767b..4a50dc94 100644 --- a/train_db.py +++ b/train_db.py @@ -233,10 +233,13 @@ def train(args): else: latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 + b_size = latents.shape[0] # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): diff --git a/train_network.py b/train_network.py index b783379b..1b8046d2 100644 --- a/train_network.py +++ b/train_network.py @@ -278,9 +278,9 @@ def train(args): # support DistributedDataParallel if type(text_encoder) == DDP: - text_encoder = text_encoder.module - unet = unet.module - network = network.module + text_encoder = text_encoder.module + unet = unet.module + network = network.module network.prepare_grad_etc(text_encoder, unet) @@ -419,6 +419,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4aa91eee..010bd04b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,6 +320,9 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)