Add noise_offset

This commit is contained in:
Kohya S
2023-02-14 21:15:48 +09:00
parent e0f007f2a9
commit 43c0a69843
5 changed files with 27 additions and 13 deletions

View File

@@ -300,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.shuffle_keep_tokens is None:
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
@@ -309,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
@@ -1102,10 +1102,10 @@ def addnet_hash_safetensors(b):
def get_git_revision_hash() -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
except:
return "(unknown)"
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
except:
return "(unknown)"
# flash attention forwards and backwards
@@ -1421,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--noise_offset", type=float, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)")
if support_dreambooth:
# DreamBooth training
@@ -1653,10 +1655,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
return encoder_hidden_states