[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)

* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Kohaku-Blueleaf
2024-03-24 14:40:18 +08:00
committed by GitHub
parent 381c44955e
commit ae97c8bfd1
5 changed files with 173 additions and 22 deletions

View File

@@ -63,6 +63,7 @@ from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ class BaseDataset(torch.utils.data.Dataset):
)
def get_image_size(self, image_path):
image = Image.open(image_path)
return image.size
return imagesize.get(image_path)
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = load_image(image_path)
@@ -1425,6 +1425,8 @@ class DreamBoothDataset(BaseDataset):
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
cache_meta: bool,
use_cached_meta: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1484,26 +1486,43 @@ class DreamBoothDataset(BaseDataset):
logger.warning(f"not directory: {subset.image_dir}")
return [], []
img_paths = glob_images(subset.image_dir, "*")
sizes = None
if use_cached_meta:
logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
# [img_path, caption, resolution]
with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
metas = f.readlines()
metas = [x.strip().split("<|##|>") for x in metas]
sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
if use_cached_meta:
img_paths = [x[0] for x in metas]
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None]*len(img_paths)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
if use_cached_meta:
captions = [x[1] for x in metas]
missing_captions = [x[0] for x in metas if x[1] == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
@@ -1520,7 +1539,21 @@ class DreamBoothDataset(BaseDataset):
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
break
logger.warning(missing_caption)
return img_paths, captions
if cache_meta:
logger.info(f"cache metadata for {subset.image_dir}")
if sizes is None or sizes[0] is None:
sizes = [self.get_image_size(img_path) for img_path in img_paths]
# [img_path, caption, resolution]
data = [
(img_path, caption, " ".join(str(x) for x in size))
for img_path, caption, size in zip(img_paths, captions, sizes)
]
with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
f.write("\n".join(["<|##|>".join(x) for x in data]))
logger.info(f"cache metadata done for {subset.image_dir}")
return img_paths, captions, sizes
logger.info("prepare images.")
num_train_images = 0
@@ -1539,7 +1572,7 @@ class DreamBoothDataset(BaseDataset):
)
continue
img_paths, captions = load_dreambooth_dir(subset)
img_paths, captions, sizes = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@ class DreamBoothDataset(BaseDataset):
else:
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
else:
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
# dataset common
parser.add_argument(
"--cache_meta", action="store_true"
)
parser.add_argument(
"--use_cached_meta", action="store_true"
)
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)