Update train_util.py

Added feature to add "." if missing in caption_extension
Added warning on training without captions
This commit is contained in:
TingTingin
2023-05-23 01:57:35 -04:00
committed by GitHub
parent b6ba4cac83
commit 5a1a14f9fc

View File

@@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
@@ -1081,16 +1083,33 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
captions.append("")
else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images")
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption+f"... and {remaining_missing_captions} more")
break
print(missing_caption)
time.sleep(5)
return img_paths, captions
print("prepare images.")