Strip caption tokens and always join them with ", "

This commit is contained in:
Yuta Hayashibe
2023-02-06 12:29:41 +09:00
parent ae33d72479
commit 5ea5fefcd2

View File

@@ -135,7 +135,7 @@ class BaseDataset(torch.utils.data.Dataset):
def process_caption(self, caption): def process_caption(self, caption):
if self.shuffle_caption: if self.shuffle_caption:
tokens = caption.strip().split(",") tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None: if self.shuffle_keep_tokens is None:
random.shuffle(tokens) random.shuffle(tokens)
else: else:
@@ -144,7 +144,7 @@ class BaseDataset(torch.utils.data.Dataset):
tokens = tokens[self.shuffle_keep_tokens:] tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens) random.shuffle(tokens)
tokens = keep_tokens + tokens tokens = keep_tokens + tokens
caption = ",".join(tokens).strip() caption = ", ".join(tokens)
for str_from, str_to in self.replacements.items(): for str_from, str_to in self.replacements.items():
if str_from == "": if str_from == "":