diff --git a/library/train_util.py b/library/train_util.py index e5b81f1f..47ce81c8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -330,7 +330,8 @@ class BaseDataset(torch.utils.data.Dataset): self.tag_frequency[dir_name] = frequency_for_dir for caption in captions: for tag in caption.split(","): - if tag and not tag.isspace(): + tag = tag.strip() + if tag: tag = tag.lower() frequency = frequency_for_dir.get(tag, 0) frequency_for_dir[tag] = frequency + 1 @@ -803,7 +804,8 @@ class DreamBoothDataset(BaseDataset): captions.append("") else: captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) - self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 return img_paths, captions