Update train_util.py

Added feature to add "." if missing in caption_extension
Added warning on training without captions
This commit is contained in:
TingTingin
2023-05-23 01:57:35 -04:00
committed by GitHub
parent b6ba4cac83
commit 5a1a14f9fc

View File

@@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg self.is_reg = is_reg
self.class_tokens = class_tokens self.class_tokens = class_tokens
self.caption_extension = caption_extension self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool: def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset): if not isinstance(other, DreamBoothSubset):
@@ -1069,7 +1071,7 @@ class DreamBoothDataset(BaseDataset):
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip() caption = lines[0].strip()
break break
return caption return caption
def load_dreambooth_dir(subset: DreamBoothSubset): def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir): if not os.path.isdir(subset.image_dir):
@@ -1081,16 +1083,33 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = [] captions = []
missing_captions = []
for img_path in img_paths: for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension) cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None: if cap_for_img is None and subset.class_tokens is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
captions.append("") captions.append("")
else: else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images")
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption+f"... and {remaining_missing_captions} more")
break
print(missing_caption)
time.sleep(5)
return img_paths, captions return img_paths, captions
print("prepare images.") print("prepare images.")