Merge pull request #533 from TingTingin/main

Added warning on training without captions
This commit is contained in:
Kohya S
2023-05-29 08:37:33 +09:00
committed by GitHub

View File

@@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
@@ -1081,16 +1083,32 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
captions.append("")
else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images")
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption+f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, captions
print("prepare images.")