mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add tag dropout
This commit is contained in:
@@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
# TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
|
||||
self.epoch_current: int = int(0)
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
|
||||
@@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
|
||||
# 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
|
||||
def set_current_epoch(self, epoch):
    """Store the epoch number handed in by the caller.

    The value is kept in ``self.current_epoch``; ``process_caption`` reads it
    to decide the "drop every N epochs" caption dropout.

    Args:
        epoch: Epoch number to remember on this dataset instance.
    """
    self.current_epoch = epoch
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate=0):
    """Configure caption-level and tag-level dropout for this dataset.

    These settings are deliberately not constructor arguments so that
    Textual Inversion code paths do not have to care about them (per the
    original author's note).

    Args:
        dropout_rate: Probability (0.0-1.0) of replacing a whole caption
            with an empty string.
        dropout_every_n_epochs: If set, every caption is dropped on epochs
            that are a multiple of this value; ``None`` disables it.
        tag_dropout_rate: Probability (0.0-1.0) of dropping each
            comma-separated tag. Defaults to 0 (disabled) so callers written
            against the earlier two-argument signature keep working.
    """
    self.dropout_rate = dropout_rate
    self.dropout_every_n_epochs = dropout_every_n_epochs
    self.tag_dropout_rate = tag_dropout_rate
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
@@ -275,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
    """Apply caption dropout, tag shuffle/dropout and replacements to a caption.

    Processing order:
      1. Decide whether the whole caption is dropped, either randomly with
         probability ``self.dropout_rate`` or because ``self.current_epoch``
         is a multiple of ``self.dropout_every_n_epochs``. A dropped caption
         is returned as ``""`` and no further processing happens.
      2. If shuffling or tag dropout is enabled, split the caption on commas.
         The first ``self.shuffle_keep_tokens`` tags are left untouched; the
         remaining tags are shuffled (when ``self.shuffle_caption``) and each
         one is independently dropped with probability ``tag_dropout_rate``.
         When the caption has no more tags than ``shuffle_keep_tokens`` the
         tags are left as they are.
      3. Apply ``self.replacements`` (Textual Inversion support): an empty
         ``str_from`` replaces the whole caption (picking randomly when the
         replacement is a list), otherwise a plain substring replace is done.

    Args:
        caption: Raw caption text, tags separated by commas.

    Returns:
        The processed caption string.
    """
    # Decide the drop-out here, because tag dropout also lives in this method.
    # NOTE: `current_epoch % n == 0` is also true for epoch 0, so callers must
    # set a non-zero epoch via set_current_epoch() unless dropping everything
    # at epoch 0 is intended.
    every_n = self.dropout_every_n_epochs
    is_drop_out = bool(
        (self.dropout_rate > 0 and random.random() < self.dropout_rate)
        or (every_n and self.current_epoch % every_n == 0)
    )
    if is_drop_out:
        return ""

    # set_caption_dropout() may never have been called (e.g. Textual
    # Inversion), so do not assume the attribute exists.
    tag_dropout_rate = getattr(self, "tag_dropout_rate", 0) or 0

    def dropout_tags(tokens):
        # Keep each tag with probability (1 - tag_dropout_rate).
        if tag_dropout_rate <= 0:
            return tokens
        return [token for token in tokens if random.random() >= tag_dropout_rate]

    # Tag dropout must work even when shuffling is disabled.
    if self.shuffle_caption or tag_dropout_rate > 0:
        tokens = [t.strip() for t in caption.strip().split(",")]
        keep_count = self.shuffle_keep_tokens or 0
        if len(tokens) > keep_count:
            kept, rest = tokens[:keep_count], tokens[keep_count:]
            if self.shuffle_caption:
                random.shuffle(rest)
            tokens = kept + dropout_tags(rest)
        caption = ", ".join(tokens)

    # Textual Inversion support.
    for str_from, str_to in self.replacements.items():
        if str_from == "":
            # Replace the whole caption.
            caption = random.choice(str_to) if isinstance(str_to, list) else str_to
        else:
            caption = caption.replace(str_from, str_to)

    return caption
|
||||
|
||||
@@ -609,17 +631,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
# dropoutの決定
|
||||
is_drop_out = False
|
||||
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
|
||||
is_drop_out = True
|
||||
if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
|
||||
is_drop_out = True
|
||||
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
caption = self.process_caption(image_info.caption)
|
||||
caption = self.process_caption(image_info.caption)
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
@@ -928,6 +940,8 @@ class FineTuningDataset(BaseDataset):
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
|
||||
train_dataset.set_current_epoch(1)
|
||||
k = 0
|
||||
for i, example in enumerate(train_dataset):
|
||||
if example['latents'] is not None:
|
||||
@@ -1436,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout comma seperated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth dataset
|
||||
|
||||
Reference in New Issue
Block a user