diff --git a/library/train_util.py b/library/train_util.py index 013cc81c..b249e61d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1200,19 +1200,27 @@ class FineTuningDataset(BaseDataset): tags_list = [] for image_key, img_md in metadata.items(): # path情報を作る + abs_path = None + + # まず画像を優先して探す if os.path.exists(image_key): abs_path = image_key - elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, image_key) + if len(paths) > 0: + abs_path = paths[0] + + # なければnpzを探す + if abs_path is None: + if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + + assert abs_path is not None, f"no image / 画像がありません: {image_key}" caption = img_md.get("caption") tags = img_md.get("tags")