strip tag, fix tag frequency count

This commit is contained in:
Kohya S
2023-03-01 22:10:15 +09:00
parent d1d7d432e9
commit 04af36e7e2

View File

@@ -330,7 +330,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.strip()
if tag:
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
@@ -803,7 +804,8 @@ class DreamBoothDataset(BaseDataset):
captions.append("")
else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
return img_paths, captions