mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add noise_offset
This commit is contained in:
@@ -300,7 +300,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
@@ -309,7 +309,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
@@ -1102,10 +1102,10 @@ def addnet_hash_safetensors(b):
|
||||
|
||||
|
||||
def get_git_revision_hash() -> str:
|
||||
try:
|
||||
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
|
||||
except:
|
||||
return "(unknown)"
|
||||
try:
|
||||
return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
|
||||
except:
|
||||
return "(unknown)"
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
@@ -1421,6 +1421,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
||||
parser.add_argument("--noise_offset", type=float, default=None,
|
||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
@@ -1653,10 +1655,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||
# this is required for additional network training
|
||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
||||
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user