diff --git a/fine_tune.py b/fine_tune.py index e743a349..52921530 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -38,7 +38,7 @@ def train(args): args.dataset_repeats, args.debug_dataset) # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) train_dataset.make_buckets() @@ -230,8 +230,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - - train_dataset.epoch_current = epoch + 1 + train_dataset.set_current_epoch(epoch + 1) for m in training_models: m.train() diff --git a/library/train_util.py b/library/train_util.py index 612eba2d..1f92af43 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset): self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - # TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう - self.epoch_current: int = int(0) + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.dropout_rate: float = 0 self.dropout_every_n_epochs: int = None @@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} - def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs): - # 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate): # コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく) self.dropout_rate = dropout_rate self.dropout_every_n_epochs = dropout_every_n_epochs + self.tag_dropout_rate = tag_dropout_rate def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) @@ -275,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements[str_from] = str_to def process_caption(self, caption): - if self.shuffle_caption: - tokens = [t.strip() for t in caption.strip().split(",")] - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ", ".join(tokens) + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate + is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) + if is_drop_out: + caption = "" + else: + if self.shuffle_caption: + def dropout_tags(tokens): + if self.tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= self.tag_dropout_rate: + l.append(token) + return l + + tokens = [t.strip() for t in caption.strip().split(",")] + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + tokens = dropout_tags(tokens) else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = dropout_tags(tokens) + + tokens = keep_tokens + tokens + caption = ", ".join(tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) return caption @@ -609,17 +631,7 @@ class BaseDataset(torch.utils.data.Dataset): images.append(image) latents_list.append(latents) - # dropoutの決定 - is_drop_out = False - if self.dropout_rate > 0 and random.random() < self.dropout_rate: - is_drop_out = True - if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0: - is_drop_out = True - - if is_drop_out: - caption = "" - else: - caption = self.process_caption(image_info.caption) + caption = self.process_caption(image_info.caption) captions.append(caption) if not self.token_padding_disabled: # this option might be omitted in future input_ids_list.append(self.get_input_ids(caption)) @@ -928,6 +940,8 @@ class FineTuningDataset(BaseDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. / Escキーで中断、終了します") + + train_dataset.set_current_epoch(1) k = 0 for i, example in enumerate(train_dataset): if example['latents'] is not None: @@ -1436,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None, help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") + parser.add_argument("--caption_tag_dropout_rate", type=float, default=0, + help="Rate out dropout comma seperated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") if support_dreambooth: # DreamBooth dataset diff --git a/train_db.py b/train_db.py index 51f5038b..c210767b 100644 --- a/train_db.py +++ b/train_db.py @@ -43,7 +43,7 @@ def train(args): train_dataset.disable_token_padding() # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) train_dataset.make_buckets() @@ -208,8 +208,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - - train_dataset.epoch_current = epoch + 1 + train_dataset.set_current_epoch(epoch + 1) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() diff --git a/train_network.py b/train_network.py index f3ca417c..bb3159fd 100644 --- a/train_network.py +++ b/train_network.py @@ -134,7 +134,7 @@ def train(args): args.dataset_repeats, args.debug_dataset) # 学習データのdropout率を設定する - train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs) + train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate) train_dataset.make_buckets() @@ -380,8 +380,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - - train_dataset.epoch_current = epoch + 1 + train_dataset.set_current_epoch(epoch + 1) metadata["ss_epoch"] = str(epoch+1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d3e558a3..ba2e7145 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -235,7 +235,7 @@ def train(args): text_encoder, optimizer, train_dataloader, lr_scheduler) index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - print(len(index_no_updates), torch.sum(index_no_updates)) + # print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() # Freeze all parameters except for the token embeddings in text encoder @@ -296,6 +296,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset.set_current_epoch(epoch + 1) text_encoder.train() @@ -383,8 +384,8 @@ def train(args): accelerator.wait_for_everyone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() - d = updated_embs - bef_epo_embs - print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) + # d = updated_embs - bef_epo_embs + # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) if args.save_every_n_epochs is not None: model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name