Fix crash TI train close #172, tag drop wo shuffle

This commit is contained in:
Kohya S
2023-02-11 14:41:44 +09:00
parent 5777be5208
commit 2c5f5c324a

View File

@@ -226,6 +226,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0
# augmentation
flip_p = 0.5 if flip_aug else 0.0
@@ -284,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if self.shuffle_caption:
if self.shuffle_caption or self.tag_dropout_rate > 0:
def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
return tokens
@@ -296,13 +297,18 @@ class BaseDataset(torch.utils.data.Dataset):
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens