refactor metadata caching for DreamBooth dataset

This commit is contained in:
Kohya S
2024-03-24 18:09:32 +09:00
parent ae97c8bfd1
commit 025347214d
6 changed files with 85 additions and 159 deletions

View File

@@ -410,6 +410,7 @@ class DreamBoothSubset(BaseSubset):
is_reg: bool,
class_tokens: Optional[str],
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator: str,
@@ -458,6 +459,7 @@ class DreamBoothSubset(BaseSubset):
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
@@ -527,6 +529,7 @@ class ControlNetSubset(BaseSubset):
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator,
@@ -574,6 +577,7 @@ class ControlNetSubset(BaseSubset):
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
@@ -1410,6 +1414,8 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
@@ -1425,8 +1431,6 @@ class DreamBoothDataset(BaseDataset):
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
cache_meta: bool,
use_cached_meta: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1486,25 +1490,36 @@ class DreamBoothDataset(BaseDataset):
logger.warning(f"not directory: {subset.image_dir}")
return [], []
sizes = None
if use_cached_meta:
logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
# [img_path, caption, resolution]
with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
metas = f.readlines()
metas = [x.strip().split("<|##|>") for x in metas]
sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
if use_cached_meta:
img_paths = [x[0] for x in metas]
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
if use_cached_info_for_subset:
logger.info(
f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
)
if not os.path.isfile(info_cache_file):
logger.warning(
f"image info file not found. You can ignore this warning if this is the first time to use this subset"
+ " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}"
)
use_cached_info_for_subset = False
if use_cached_info_for_subset:
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
with open(info_cache_file, "r", encoding="utf-8") as f:
metas = json.load(f)
img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None]*len(img_paths)
sizes = [None] * len(img_paths)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_meta:
captions = [x[1] for x in metas]
missing_captions = [x[0] for x in metas if x[1] == ""]
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
@@ -1540,19 +1555,17 @@ class DreamBoothDataset(BaseDataset):
break
logger.warning(missing_caption)
if cache_meta:
logger.info(f"cache metadata for {subset.image_dir}")
if sizes is None or sizes[0] is None:
sizes = [self.get_image_size(img_path) for img_path in img_paths]
# [img_path, caption, resolution]
data = [
(img_path, caption, " ".join(str(x) for x in size))
for img_path, caption, size in zip(img_paths, captions, sizes)
]
with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
f.write("\n".join(["<|##|>".join(x) for x in data]))
logger.info(f"cache metadata done for {subset.image_dir}")
if not use_cached_info_for_subset and subset.cache_info:
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
matas = {}
for img_path, caption, size in zip(img_paths, captions, sizes):
matas[img_path] = {"caption": caption, "resolution": list(size)}
with open(info_cache_file, "w", encoding="utf-8") as f:
json.dump(matas, f, ensure_ascii=False, indent=2)
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
# if sizes are not set, image size will be read in make_buckets
return img_paths, captions, sizes
logger.info("prepare images.")
@@ -1873,7 +1886,8 @@ class ControlNetDataset(BaseDataset):
subset.image_dir,
False,
None,
subset.caption_extension,
subset.caption_extension,
subset.cache_info,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
@@ -3390,15 +3404,15 @@ def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
# dataset common
parser.add_argument(
"--cache_meta", action="store_true"
)
parser.add_argument(
"--use_cached_meta", action="store_true"
)
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--cache_info",
action="store_true",
help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
+ " / メタ情報キャプションとサイズをキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
)
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
)