add tag dropout

This commit is contained in:
Kohya S
2023-02-09 21:35:27 +09:00
parent f7b5abb595
commit 3a72e6f003
5 changed files with 60 additions and 46 deletions

View File

@@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
# TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
self.epoch_current: int = int(0)
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
@@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
# 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
# Store the epoch number supplied by the caller on the dataset instance.
# process_caption reads self.current_epoch for the "drop every N epochs" check.
def set_current_epoch(self, epoch):
self.current_epoch = epoch
# Configure caption dropout and per-tag dropout for this dataset.
#   dropout_rate: compared against random.random() in process_caption; when the
#       draw is below it, the whole caption becomes "".
#   dropout_every_n_epochs: when truthy, the whole caption is dropped on epochs
#       where current_epoch % dropout_every_n_epochs == 0.
#   tag_dropout_rate: compared against random.random() once per comma-separated
#       token in process_caption; tokens whose draw is below it are removed.
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
# These are deliberately not constructor arguments, so that Textual Inversion code does not have to be aware of them
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
self.tag_dropout_rate = tag_dropout_rate
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -275,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements[str_from] = str_to
# Turn a raw caption string into the caption used for training: optionally drop
# it entirely, optionally shuffle / drop its comma-separated tokens, then apply
# the entries of self.replacements. Returns the resulting string.
# NOTE(review): this span reads like diff-hunk text whose +/- markers and
# indentation were lost -- the shuffle logic and the replacements loop each
# appear twice (the first shuffle copy has no dropout_tags call). Confirm which
# lines are current before relying on it.
def process_caption(self, caption):
if self.shuffle_caption:
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# Decide caption dropout here: tag dropout lives in this method, so this is the best place to do it
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
# Parses as: is_drop_out or (dropout_every_n_epochs and current_epoch % dropout_every_n_epochs == 0).
# NOTE(review): current_epoch is initialised to 0 and 0 % n == 0, so every caption
# would be dropped until set_current_epoch is called -- confirm epochs are 1-based.
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
if is_drop_out:
caption = ""
else:
# NOTE(review): tag dropout is only reached inside this shuffle_caption branch --
# confirm that a non-zero tag_dropout_rate is meant to be ignored when shuffling is off.
if self.shuffle_caption:
# Keep each token whose random draw is >= tag_dropout_rate; a rate <= 0 keeps all tokens.
def dropout_tags(tokens):
# NOTE(review): tag_dropout_rate is only assigned in set_caption_dropout in the
# visible code -- confirm it is initialised before this method can run.
if self.tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= self.tag_dropout_rate:
l.append(token)
return l
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
# The first shuffle_keep_tokens tokens are kept in place, exempt from shuffle and tag dropout.
# NOTE(review): when len(tokens) <= shuffle_keep_tokens nothing is shuffled or dropped.
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# Textual Inversion support
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
@@ -609,17 +631,7 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image)
latents_list.append(latents)
# dropoutの決定
is_drop_out = False
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
is_drop_out = True
if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
is_drop_out = True
if is_drop_out:
caption = ""
else:
caption = self.process_caption(image_info.caption)
caption = self.process_caption(image_info.caption)
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
@@ -928,6 +940,8 @@ class FineTuningDataset(BaseDataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
train_dataset.set_current_epoch(1)
k = 0
for i, example in enumerate(train_dataset):
if example['latents'] is not None:
@@ -1436,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
help="Rate out dropout comma seperated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
if support_dreambooth:
# DreamBooth dataset