diff --git a/library/train_util.py b/library/train_util.py index 6f809deb..cdd5860a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -265,7 +265,7 @@ class BaseDataset(torch.utils.data.Dataset): def process_caption(self, caption): if self.shuffle_caption: - tokens = caption.strip().split(",") + tokens = [t.strip() for t in caption.strip().split(",")] if self.shuffle_keep_tokens is None: random.shuffle(tokens) else: @@ -274,7 +274,7 @@ class BaseDataset(torch.utils.data.Dataset): tokens = tokens[self.shuffle_keep_tokens:] random.shuffle(tokens) tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() + caption = ", ".join(tokens) for str_from, str_to in self.replacements.items(): if str_from == "":