Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initialization (#1178)
* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
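In short: with --cache_meta, load_dreambooth_dir writes one dataset.txt per image directory recording each image's path, caption, and resolution; with --use_cached_meta, later runs read that file back instead of globbing the directory and opening every image. A minimal sketch of the cache format the diff below reads and writes (the helper names are illustrative, not part of the commit):

    # One image per line; fields joined by the literal delimiter "<|##|>",
    # size stored as "width height". write_cache/read_cache are illustrative only.
    DELIM = "<|##|>"

    def write_cache(path, entries):
        # entries: list of (img_path, caption, (width, height))
        lines = [DELIM.join((p, c, f"{w} {h}")) for p, c, (w, h) in entries]
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))

    def read_cache(path):
        with open(path, "r", encoding="utf-8") as f:
            metas = [x.strip().split(DELIM) for x in f.readlines()]
        img_paths = [m[0] for m in metas]
        captions = [m[1] for m in metas]
        sizes = [tuple(int(v) for v in m[2].split(" ")) for m in metas]
        return img_paths, captions, sizes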
library/train_util.py
@@ -63,6 +63,7 @@ from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import numpy as np
 from PIL import Image
+import imagesize
 import cv2
 import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
@@ -1080,8 +1081,7 @@ class BaseDataset(torch.utils.data.Dataset):
         )
 
     def get_image_size(self, image_path):
-        image = Image.open(image_path)
-        return image.size
+        return imagesize.get(image_path)
 
     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
         img = load_image(image_path)
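The get_image_size change is the core speedup: imagesize.get parses only the file header to return a (width, height) tuple, whereas PIL's Image.open builds a full Image object per file. A quick sketch of the two calls (the file path is hypothetical):

    import imagesize
    from PIL import Image

    path = "train_imgs/0001.png"  # hypothetical image

    # imagesize reads just the header bytes
    w, h = imagesize.get(path)

    # PIL is lazy but still constructs an Image object before exposing .size
    with Image.open(path) as im:
        assert im.size == (w, h)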
@@ -1425,6 +1425,8 @@ class DreamBoothDataset(BaseDataset):
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        cache_meta: bool,
+        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
 
@@ -1484,26 +1486,43 @@ class DreamBoothDataset(BaseDataset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []
 
-            img_paths = glob_images(subset.image_dir, "*")
+            sizes = None
+            if use_cached_meta:
+                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
+                # [img_path, caption, resolution]
+                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
+                    metas = f.readlines()
+                metas = [x.strip().split("<|##|>") for x in metas]
+                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
+
+            if use_cached_meta:
+                img_paths = [x[0] for x in metas]
+            else:
+                img_paths = glob_images(subset.image_dir, "*")
+                sizes = [None]*len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
-            # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-            captions = []
-            missing_captions = []
-            for img_path in img_paths:
-                cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                if cap_for_img is None and subset.class_tokens is None:
-                    logger.warning(
-                        f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
-                    )
-                    captions.append("")
-                    missing_captions.append(img_path)
-                else:
-                    if cap_for_img is None:
-                        captions.append(subset.class_tokens)
-                        missing_captions.append(img_path)
-                    else:
-                        captions.append(cap_for_img)
+            if use_cached_meta:
+                captions = [x[1] for x in metas]
+                missing_captions = [x[0] for x in metas if x[1] == ""]
+            else:
+                # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
+                captions = []
+                missing_captions = []
+                for img_path in img_paths:
+                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
+                    if cap_for_img is None and subset.class_tokens is None:
+                        logger.warning(
+                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
+                        )
+                        captions.append("")
+                        missing_captions.append(img_path)
+                    else:
+                        if cap_for_img is None:
+                            captions.append(subset.class_tokens)
+                            missing_captions.append(img_path)
+                        else:
+                            captions.append(cap_for_img)
 
             self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
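Two things worth noting in this hunk. First, the cache format is fragile by construction: each line is split on the literal "<|##|>" and the code indexes fields [0], [1], [2], so a caption that itself contains the delimiter or a newline would silently truncate or corrupt entries. A parsing sketch with a hypothetical line:

    line = "imgs/cat.png<|##|>a photo of a cat<|##|>512 768"  # hypothetical entry
    img_path, caption, res = line.strip().split("<|##|>")
    size = tuple(int(v) for v in res.split(" "))  # (512, 768)

Second, the early return for a missing directory still yields two values (return [], []) while the function now returns three; the three-way unpacking added below (img_paths, captions, sizes = load_dreambooth_dir(subset)) would raise ValueError in that case.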
@@ -1520,7 +1539,21 @@ class DreamBoothDataset(BaseDataset):
                         logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
                         break
                     logger.warning(missing_caption)
-            return img_paths, captions
+
+            if cache_meta:
+                logger.info(f"cache metadata for {subset.image_dir}")
+                if sizes is None or sizes[0] is None:
+                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
+                # [img_path, caption, resolution]
+                data = [
+                    (img_path, caption, " ".join(str(x) for x in size))
+                    for img_path, caption, size in zip(img_paths, captions, sizes)
+                ]
+                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
+                    f.write("\n".join(["<|##|>".join(x) for x in data]))
+                logger.info(f"cache metadata done for {subset.image_dir}")
+
+            return img_paths, captions, sizes
 
         logger.info("prepare images.")
         num_train_images = 0
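The sentinel check here ties the two paths together: the non-cached branch earlier filled sizes with None placeholders, so inspecting sizes[0] tells whether real resolutions are already known (e.g. loaded from a previous dataset.txt) or must be probed now via get_image_size. In isolation:

    img_paths = ["a.png", "b.png"]   # hypothetical
    sizes = [None] * len(img_paths)  # placeholders from the glob branch
    need_probe = sizes is None or sizes[0] is None
    print(need_probe)                # True -> probe each file's header

Also note the cache is written after caption resolution, so class_tokens fallbacks are baked into dataset.txt; with --use_cached_meta, later edits to caption files are ignored until the cache is regenerated.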
@@ -1539,7 +1572,7 @@ class DreamBoothDataset(BaseDataset):
                 )
                 continue
 
-            img_paths, captions = load_dreambooth_dir(subset)
+            img_paths, captions, sizes = load_dreambooth_dir(subset)
             if len(img_paths) < 1:
                 logger.warning(
                     f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1551,8 +1584,10 @@ class DreamBoothDataset(BaseDataset):
             else:
                 num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption in zip(img_paths, captions):
+            for img_path, caption, size in zip(img_paths, captions, sizes):
                 info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
+                if size is not None:
+                    info.image_size = size
                 if subset.is_reg:
                     reg_infos.append((info, subset))
                 else:
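Carrying the cached size onto ImageInfo presumably lets downstream bucketing use it instead of opening each file again; the None guard keeps the old behavior when no cache is involved. The pattern as a standalone sketch (resolve_size is illustrative, not from the commit):

    import imagesize

    def resolve_size(img_path, cached_size):
        # Prefer the size recorded in dataset.txt; probe the header on a miss.
        return cached_size if cached_size is not None else imagesize.get(img_path)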
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
+    parser.add_argument(
+        "--cache_meta", action="store_true"
+    )
+    parser.add_argument(
+        "--use_cached_meta", action="store_true"
+    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
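Both switches are plain store_true flags with no help text. As far as the diff shows, the intended workflow is: run once with --cache_meta so each image directory gets a dataset.txt, then add --use_cached_meta on subsequent runs to skip the slow scan. A minimal argparse sketch of the added flags:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_meta", action="store_true")       # write dataset.txt
    parser.add_argument("--use_cached_meta", action="store_true")  # read it back

    args = parser.parse_args(["--cache_meta"])
    assert args.cache_meta and not args.use_cached_meta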