add tag dropout

Author: Kohya S
Date: 2023-02-09 21:35:27 +09:00
parent f7b5abb595
commit 3a72e6f003

5 changed files with 60 additions and 46 deletions

fine_tune.py

@@ -38,7 +38,7 @@ def train(args):
                                     args.dataset_repeats, args.debug_dataset)

   # set the dropout rates for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

   train_dataset.make_buckets()
@@ -230,8 +230,7 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
+    train_dataset.set_current_epoch(epoch + 1)

     for m in training_models:
       m.train()

library/train_util.py

@@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

-    # TODO passing this in from outside would be safer, but computing it automatically spares the caller extra code
-    self.epoch_current: int = int(0)
+    self.current_epoch: int = 0  # dataset instances seem to be recreated every epoch, so the epoch must be passed in from outside

     self.dropout_rate: float = 0
     self.dropout_every_n_epochs: int = None
@@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset):
     self.replacements = {}

-  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
-    # adding a setter method because we want to support tag dropout in the future as well
+  def set_current_epoch(self, epoch):
+    self.current_epoch = epoch
+
+  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
     # not passed via the constructor so that Textual Inversion does not have to care about it
     self.dropout_rate = dropout_rate
     self.dropout_every_n_epochs = dropout_every_n_epochs
+    self.tag_dropout_rate = tag_dropout_rate

   def set_tag_frequency(self, dir_name, captions):
     frequency_for_dir = self.tag_frequency.get(dir_name, {})
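
Taken together, the two setters above imply a particular calling pattern, which the training scripts in this commit all follow. A minimal sketch of that wiring (the surrounding function is illustrative, not from the commit):

# Minimal sketch: dropout is configured once before bucketing, and the
# 1-based epoch counter is pushed in at the top of every epoch, because the
# dataset instance lives across epochs.
def run_training(train_dataset, args, num_train_epochs):
    train_dataset.set_caption_dropout(args.caption_dropout_rate,
                                      args.caption_dropout_every_n_epochs,
                                      args.caption_tag_dropout_rate)
    train_dataset.make_buckets()

    for epoch in range(num_train_epochs):
        # 1-based, so with caption_dropout_every_n_epochs=N all captions are
        # dropped on epochs N, 2N, 3N, ...
        train_dataset.set_current_epoch(epoch + 1)
        # ... one epoch of training steps ...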
@@ -275,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset):
     self.replacements[str_from] = str_to

   def process_caption(self, caption):
-    if self.shuffle_caption:
-      tokens = [t.strip() for t in caption.strip().split(",")]
-      if self.shuffle_keep_tokens is None:
-        random.shuffle(tokens)
-      else:
-        if len(tokens) > self.shuffle_keep_tokens:
-          keep_tokens = tokens[:self.shuffle_keep_tokens]
-          tokens = tokens[self.shuffle_keep_tokens:]
-          random.shuffle(tokens)
-          tokens = keep_tokens + tokens
-      caption = ", ".join(tokens)
-
-    for str_from, str_to in self.replacements.items():
-      if str_from == "":
-        # replace all
-        if type(str_to) == list:
-          caption = random.choice(str_to)
-        else:
-          caption = str_to
-      else:
-        caption = caption.replace(str_from, str_to)
+    # decide dropout here: tag dropout lives in this method, so this is the right place
+    is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
+    is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
+
+    if is_drop_out:
+      caption = ""
+    else:
+      if self.shuffle_caption:
+        def dropout_tags(tokens):
+          if self.tag_dropout_rate <= 0:
+            return tokens
+          l = []
+          for token in tokens:
+            if random.random() >= self.tag_dropout_rate:
+              l.append(token)
+          return l
+
+        tokens = [t.strip() for t in caption.strip().split(",")]
+        if self.shuffle_keep_tokens is None:
+          random.shuffle(tokens)
+          tokens = dropout_tags(tokens)
+        else:
+          if len(tokens) > self.shuffle_keep_tokens:
+            keep_tokens = tokens[:self.shuffle_keep_tokens]
+            tokens = tokens[self.shuffle_keep_tokens:]
+            random.shuffle(tokens)
+            tokens = dropout_tags(tokens)
+            tokens = keep_tokens + tokens
+        caption = ", ".join(tokens)
+
+      # handle textual inversion replacements
+      for str_from, str_to in self.replacements.items():
+        if str_from == "":
+          # replace all
+          if type(str_to) == list:
+            caption = random.choice(str_to)
+          else:
+            caption = str_to
+        else:
+          caption = caption.replace(str_from, str_to)

     return caption
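
Two details of the new logic are easy to miss: "a or b and c" parses as "a or (b and c)", so the every-N-epochs check fires only as a group; and dropout_tags runs only inside the "if self.shuffle_caption:" branch, so tag dropout takes effect only when caption shuffling is enabled. A standalone sketch of the per-tag behavior (the function name is illustrative, not from the commit):

import random

# Each comma-separated tag survives independently with probability
# 1 - tag_dropout_rate, mirroring dropout_tags() above.
def simulate_tag_dropout(caption, tag_dropout_rate):
    tokens = [t.strip() for t in caption.strip().split(",")]
    kept = [t for t in tokens if random.random() >= tag_dropout_rate]
    return ", ".join(kept)

caption = "1girl, solo, long hair, smile, outdoors"
for _ in range(3):
    print(simulate_tag_dropout(caption, 0.3))
# each line keeps a random ~70% of the tags; order is preserved here,
# whereas process_caption() also shuffles before dropping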
@@ -609,17 +631,7 @@ class BaseDataset(torch.utils.data.Dataset):
         images.append(image)
       latents_list.append(latents)

-      # decide whether to drop out
-      is_drop_out = False
-      if self.dropout_rate > 0 and random.random() < self.dropout_rate:
-        is_drop_out = True
-      if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
-        is_drop_out = True
-
-      if is_drop_out:
-        caption = ""
-      else:
-        caption = self.process_caption(image_info.caption)
+      caption = self.process_caption(image_info.caption)
       captions.append(caption)
       if not self.token_padding_disabled:  # this option might be omitted in future
         input_ids_list.append(self.get_input_ids(caption))
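
Moving the decision into process_caption keeps all caption randomness in one place; __getitem__ now simply calls it per image. For intuition, a back-of-the-envelope sketch (not from the commit) of how often a given tag reaches the text encoder on an epoch where the every-N-epochs rule does not fire, assuming shuffle_caption is on:

# A caption survives with probability 1 - caption_dropout_rate, and each tag
# in a surviving caption then survives with probability
# 1 - caption_tag_dropout_rate.
caption_dropout_rate = 0.05
caption_tag_dropout_rate = 0.1
p_tag_seen = (1 - caption_dropout_rate) * (1 - caption_tag_dropout_rate)
print(f"{p_tag_seen:.3f}")  # 0.855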
@@ -928,6 +940,8 @@ class FineTuningDataset(BaseDataset):
 def debug_dataset(train_dataset, show_input_ids=False):
   print(f"Total dataset length (steps): {len(train_dataset)}")
   print("Press the Esc key to abort and exit.")

+  train_dataset.set_current_epoch(1)
+
   k = 0
   for i, example in enumerate(train_dataset):
     if example['latents'] is not None:
@@ -1436,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
                       help="Rate of dropout for captions (0.0~1.0)")
   parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
                       help="Drop out all captions every N epochs")
+  parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
+                      help="Rate of dropout for comma-separated tags (0.0~1.0)")

   if support_dreambooth:
     # DreamBooth dataset
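
A sketch of how the three flags parse together (the declaration of --caption_dropout_rate is mostly outside the hunk above, so its type and default are assumptions here):

import argparse

parser = argparse.ArgumentParser()
# assumed to match the declaration cut off above
parser.add_argument("--caption_dropout_rate", type=float, default=0,
                    help="Rate of dropout for captions (0.0~1.0)")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
                    help="Drop out all captions every N epochs")
# the new flag from this commit
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
                    help="Rate of dropout for comma-separated tags (0.0~1.0)")

args = parser.parse_args(
    "--caption_dropout_rate 0.05 --caption_tag_dropout_rate 0.1".split())
print(args.caption_dropout_every_n_epochs)  # None -> every-N rule disabled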

train_db.py

@@ -43,7 +43,7 @@ def train(args):
     train_dataset.disable_token_padding()

   # set the dropout rates for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

   train_dataset.make_buckets()
@@ -208,8 +208,7 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
+    train_dataset.set_current_epoch(epoch + 1)

     # train the Text Encoder only up to the specified number of steps (state at the start of the epoch)
     unet.train()

train_network.py

@@ -134,7 +134,7 @@ def train(args):
                                     args.dataset_repeats, args.debug_dataset)

   # set the dropout rates for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

   train_dataset.make_buckets()
@@ -380,8 +380,7 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
+    train_dataset.set_current_epoch(epoch + 1)

     metadata["ss_epoch"] = str(epoch+1)

train_textual_inversion.py

@@ -235,7 +235,7 @@ def train(args):
       text_encoder, optimizer, train_dataloader, lr_scheduler)

   index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
-  print(len(index_no_updates), torch.sum(index_no_updates))
+  # print(len(index_no_updates), torch.sum(index_no_updates))

   orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

   # Freeze all parameters except for the token embeddings in text encoder
@@ -296,6 +296,7 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset.set_current_epoch(epoch + 1)

     text_encoder.train()
@@ -383,8 +384,8 @@ def train(args):
       accelerator.wait_for_everyone()

       updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
-      d = updated_embs - bef_epo_embs
-      print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
+      # d = updated_embs - bef_epo_embs
+      # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())

       if args.save_every_n_epochs is not None:
         model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name