mirror of https://github.com/kohya-ss/sd-scripts.git
Commit: add tag dropout
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -38,7 +38,7 @@ def train(args):
                                                args.dataset_repeats, args.debug_dataset)
 
   # set the dropout rate for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
 
   train_dataset.make_buckets()
 
@@ -230,8 +230,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
-
+    train_dataset.set_current_epoch(epoch + 1)
 
     for m in training_models:
       m.train()
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset):
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
-    # TODO passing this in from outside would be safer, but computing it automatically spares the caller extra code
-    self.epoch_current: int = int(0)
+    self.current_epoch: int = 0  # the instance seems to be recreated every epoch, so this has to be passed in from outside
     self.dropout_rate: float = 0
     self.dropout_every_n_epochs: int = None
 
@@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset):
 
     self.replacements = {}
 
-  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
-    # adding this method now since we will want to support tag dropout later
+  def set_current_epoch(self, epoch):
+    self.current_epoch = epoch
+
+  def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
     # not passed via the constructor so Textual Inversion doesn't have to know about it (officially, anyway)
     self.dropout_rate = dropout_rate
     self.dropout_every_n_epochs = dropout_every_n_epochs
+    self.tag_dropout_rate = tag_dropout_rate
 
   def set_tag_frequency(self, dir_name, captions):
     frequency_for_dir = self.tag_frequency.get(dir_name, {})
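For orientation, here is a minimal usage sketch, not part of the commit, of how a training script wires the new API together: configure the dropout knobs once before bucketing, then report the 1-based epoch at the top of each epoch so the every-N-epochs check can fire. `train_dataset` stands for any BaseDataset subclass and the values are illustrative.

# Minimal usage sketch (illustrative values, not from the commit).
train_dataset.set_caption_dropout(dropout_rate=0.05,            # drop the whole caption on ~5% of steps
                                  dropout_every_n_epochs=None,  # or e.g. 10: drop all captions every 10th epoch
                                  tag_dropout_rate=0.1)         # drop each comma-separated tag with p=0.1
train_dataset.make_buckets()

for epoch in range(num_train_epochs):
  train_dataset.set_current_epoch(epoch + 1)  # 1-based; the default 0 would trip the modulo check in the hunk below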
@@ -275,18 +277,38 @@ class BaseDataset(torch.utils.data.Dataset):
     self.replacements[str_from] = str_to
 
   def process_caption(self, caption):
+    # decide dropout here: tag dropout also lives in this method, so this is the right place
+    is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
+    is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
+
+    if is_drop_out:
+      caption = ""
+    else:
       if self.shuffle_caption:
+        def dropout_tags(tokens):
+          if self.tag_dropout_rate <= 0:
+            return tokens
+          l = []
+          for token in tokens:
+            if random.random() >= self.tag_dropout_rate:
+              l.append(token)
+          return l
+
         tokens = [t.strip() for t in caption.strip().split(",")]
         if self.shuffle_keep_tokens is None:
           random.shuffle(tokens)
+          tokens = dropout_tags(tokens)
         else:
           if len(tokens) > self.shuffle_keep_tokens:
             keep_tokens = tokens[:self.shuffle_keep_tokens]
             tokens = tokens[self.shuffle_keep_tokens:]
             random.shuffle(tokens)
+            tokens = dropout_tags(tokens)
 
             tokens = keep_tokens + tokens
         caption = ", ".join(tokens)
 
+      # textual inversion support
       for str_from, str_to in self.replacements.items():
         if str_from == "":
           # replace all
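Two details are easy to miss here. First, `and` binds tighter than `or` in Python, so the second assignment parses as `is_drop_out or (every_n and current_epoch % every_n == 0)`: the every-N-epochs rule fires regardless of the random draw. Second, `dropout_tags` is defined and called only under `if self.shuffle_caption:`, and the `keep_tokens` head is split off before it runs, so tag dropout requires caption shuffling and never removes the kept leading tags. Centralizing the decision here also lets the hunk below delete the duplicated per-item dropout block from __getitem__, giving DreamBooth, fine tuning, and LoRA training one shared code path. A standalone replica of the tag-dropout step, for illustration only:

import random

# Standalone replica of the commit's dropout_tags logic: each comma-separated
# tag survives independently with probability 1 - tag_dropout_rate.
def dropout_tags(tokens, tag_dropout_rate):
  if tag_dropout_rate <= 0:
    return tokens
  return [t for t in tokens if random.random() >= tag_dropout_rate]

random.seed(0)
tags = [t.strip() for t in "1girl, solo, smile, looking at viewer".split(",")]
print(dropout_tags(tags, 0.5))  # with this seed: ['1girl', 'solo']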
@@ -609,16 +631,6 @@ class BaseDataset(torch.utils.data.Dataset):
       images.append(image)
       latents_list.append(latents)
 
-      # decide dropout
-      is_drop_out = False
-      if self.dropout_rate > 0 and random.random() < self.dropout_rate:
-        is_drop_out = True
-      if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
-        is_drop_out = True
-
-      if is_drop_out:
-        caption = ""
-      else:
       caption = self.process_caption(image_info.caption)
       captions.append(caption)
       if not self.token_padding_disabled:  # this option might be omitted in future
@@ -928,6 +940,8 @@ class FineTuningDataset(BaseDataset):
 def debug_dataset(train_dataset, show_input_ids=False):
   print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
   print("Escape for exit. / Escキーで中断、終了します")
 
+  train_dataset.set_current_epoch(1)
+
   k = 0
   for i, example in enumerate(train_dataset):
     if example['latents'] is not None:
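The new set_current_epoch(1) call looks cosmetic but is not (my reading of the code, not stated in the commit): current_epoch defaults to 0, and 0 % n == 0 for every n, so iterating the dataset without setting an epoch would drop every caption whenever --caption_dropout_every_n_epochs is set. Pinning epoch 1 gives the debug view the same behavior as the first training epoch:

# Illustration of the modulo check behind set_current_epoch(1):
for n in (1, 2, 5):
  assert 0 % n == 0                # default epoch 0 would always trigger every-n dropout
  assert (1 % n == 0) == (n == 1)  # epoch 1 triggers it only for n == 1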
@@ -1436,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
                       help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
   parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
                       help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
+  parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
+                      help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
 
   if support_dreambooth:
     # DreamBooth dataset
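With this, the dataset arguments expose three dropout controls. A self-contained sketch of how they parse; the default of --caption_dropout_rate is an assumption, since only its help text is visible in this diff:

import argparse

# Mirrors the three flags above; caption_dropout_rate's default is assumed to be 0.
parser = argparse.ArgumentParser()
parser.add_argument("--caption_dropout_rate", type=float, default=0)              # drop the whole caption with this probability
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None)   # drop all captions on every Nth epoch
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0)          # drop each comma-separated tag with this probability

args = parser.parse_args(["--caption_tag_dropout_rate", "0.1"])
assert args.caption_dropout_rate == 0 and args.caption_tag_dropout_rate == 0.1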
--- a/train_db.py
+++ b/train_db.py
@@ -43,7 +43,7 @@ def train(args):
     train_dataset.disable_token_padding()
 
   # set the dropout rate for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
 
   train_dataset.make_buckets()
 
@@ -208,8 +208,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
-
+    train_dataset.set_current_epoch(epoch + 1)
 
     # train the Text Encoder up to the specified number of steps: state at the start of the epoch
     unet.train()
--- a/train_network.py
+++ b/train_network.py
@@ -134,7 +134,7 @@ def train(args):
                                                args.dataset_repeats, args.debug_dataset)
 
   # set the dropout rate for the training data
-  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
+  train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
 
   train_dataset.make_buckets()
 
@@ -380,8 +380,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset.epoch_current = epoch + 1
-
+    train_dataset.set_current_epoch(epoch + 1)
 
     metadata["ss_epoch"] = str(epoch+1)
 
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -235,7 +235,7 @@ def train(args):
       text_encoder, optimizer, train_dataloader, lr_scheduler)
 
   index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
-  print(len(index_no_updates), torch.sum(index_no_updates))
+  # print(len(index_no_updates), torch.sum(index_no_updates))
   orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
 
   # Freeze all parameters except for the token embeddings in text encoder
@@ -296,6 +296,7 @@ def train(args):
 
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
+    train_dataset.set_current_epoch(epoch + 1)
 
     text_encoder.train()
 
@@ -383,8 +384,8 @@ def train(args):
     accelerator.wait_for_everyone()
 
     updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
-    d = updated_embs - bef_epo_embs
-    print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
+    # d = updated_embs - bef_epo_embs
+    # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
 
     if args.save_every_n_epochs is not None:
       model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name