diff --git a/library/train_util.py b/library/train_util.py index 41afc13b..05ec7f84 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset): self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension def __eq__(self, other) -> bool: if not isinstance(other, DreamBoothSubset): @@ -1069,7 +1071,7 @@ class DreamBoothDataset(BaseDataset): assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() break - return caption + return caption def load_dreambooth_dir(subset: DreamBoothSubset): if not os.path.isdir(subset.image_dir): @@ -1081,16 +1083,33 @@ class DreamBoothDataset(BaseDataset): # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] + missing_captions = [] for img_path in img_paths: cap_for_img = read_caption(img_path, subset.caption_extension) if cap_for_img is None and subset.class_tokens is None: print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") captions.append("") else: - captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + if cap_for_img is None: + captions.append(subset.class_tokens) + missing_captions.append(img_path) + else: + captions.append(cap_for_img) self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + print(f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images") + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + print(missing_caption+f"... and {remaining_missing_captions} more") + break + print(missing_caption) + time.sleep(5) return img_paths, captions print("prepare images.")