From 2c5f5c324a0527623b66d729130894a93c97d651 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 11 Feb 2023 14:41:44 +0900 Subject: [PATCH] Fix crash TI train close #172, tag drop wo shuffle --- library/train_util.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d80ef516..4ec26770 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -226,6 +226,7 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.dropout_rate: float = 0 self.dropout_every_n_epochs: int = None + self.tag_dropout_rate: float = 0 # augmentation flip_p = 0.5 if flip_aug else 0.0 @@ -284,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: - if self.shuffle_caption: + if self.shuffle_caption or self.tag_dropout_rate > 0: def dropout_tags(tokens): if self.tag_dropout_rate <= 0: return tokens @@ -296,13 +297,18 @@ class BaseDataset(torch.utils.data.Dataset): tokens = [t.strip() for t in caption.strip().split(",")] if self.shuffle_keep_tokens is None: - random.shuffle(tokens) + if self.shuffle_caption: + random.shuffle(tokens) + tokens = dropout_tags(tokens) else: if len(tokens) > self.shuffle_keep_tokens: keep_tokens = tokens[:self.shuffle_keep_tokens] tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) + + if self.shuffle_caption: + random.shuffle(tokens) + tokens = dropout_tags(tokens) tokens = keep_tokens + tokens