mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset * add cache meta scripts * random ip_noise_gamma strength * random noise_offset strength * use correct settings for parser * cache path/caption/size only * revert mess up commit * revert mess up commit * Update requirements.txt * Add arguments for meta cache. * remove pickle implementation * Return sizes when enable cache --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
103
cache_dataset_meta.py
Normal file
103
cache_dataset_meta.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
import library.config_util as config_util
|
||||||
|
from library.config_util import (
|
||||||
|
ConfigSanitizer,
|
||||||
|
BlueprintGenerator,
|
||||||
|
)
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
|
from library.utils import setup_logging, add_logging_arguments
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def make_dataset(args):
|
||||||
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
|
use_dreambooth_method = args.in_json is None
|
||||||
|
use_user_config = args.dataset_config is not None
|
||||||
|
|
||||||
|
if args.seed is None:
|
||||||
|
args.seed = random.randint(0, 2**32)
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
# データセットを準備する
|
||||||
|
if args.dataset_class is None:
|
||||||
|
blueprint_generator = BlueprintGenerator(
|
||||||
|
ConfigSanitizer(True, True, False, True)
|
||||||
|
)
|
||||||
|
if use_user_config:
|
||||||
|
logger.info(f"Loading dataset config from {args.dataset_config}")
|
||||||
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
|
logger.warning(
|
||||||
|
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
|
", ".join(ignored)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if use_dreambooth_method:
|
||||||
|
logger.info("Using DreamBooth method.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||||
|
args.train_data_dir, args.reg_data_dir
|
||||||
|
)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.info("Training with captions.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": [
|
||||||
|
{
|
||||||
|
"image_dir": args.train_data_dir,
|
||||||
|
"metadata_file": args.in_json,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
|
||||||
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
|
||||||
|
blueprint.dataset_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# use arbitrary dataset class
|
||||||
|
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
|
||||||
|
return train_dataset_group
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
add_logging_arguments(parser)
|
||||||
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
|
train_util.add_training_arguments(parser, True)
|
||||||
|
config_util.add_config_arguments(parser)
|
||||||
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = setup_parser()
|
||||||
|
|
||||||
|
args, unknown = parser.parse_known_args()
|
||||||
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
if args.max_token_length is None:
|
||||||
|
args.max_token_length = 75
|
||||||
|
args.cache_meta = True
|
||||||
|
|
||||||
|
dataset_group = make_dataset(args)
|
||||||
@@ -111,6 +111,8 @@ class DreamBoothDatasetParams(BaseDatasetParams):
|
|||||||
bucket_reso_steps: int = 64
|
bucket_reso_steps: int = 64
|
||||||
bucket_no_upscale: bool = False
|
bucket_no_upscale: bool = False
|
||||||
prior_loss_weight: float = 1.0
|
prior_loss_weight: float = 1.0
|
||||||
|
cache_meta: bool = False
|
||||||
|
use_cached_meta: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -228,6 +230,8 @@ class ConfigSanitizer:
|
|||||||
"min_bucket_reso": int,
|
"min_bucket_reso": int,
|
||||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||||
"network_multiplier": float,
|
"network_multiplier": float,
|
||||||
|
"cache_meta": bool,
|
||||||
|
"use_cached_meta": bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
# options handled by argparse but not handled by user config
|
# options handled by argparse but not handled by user config
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from library.original_unet import UNet2DConditionModel
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import imagesize
|
||||||
import cv2
|
import cv2
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
@@ -1080,8 +1081,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_image_size(self, image_path):
|
def get_image_size(self, image_path):
|
||||||
image = Image.open(image_path)
|
return imagesize.get(image_path)
|
||||||
return image.size
|
|
||||||
|
|
||||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
||||||
img = load_image(image_path)
|
img = load_image(image_path)
|
||||||
@@ -1425,6 +1425,8 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
bucket_no_upscale: bool,
|
bucket_no_upscale: bool,
|
||||||
prior_loss_weight: float,
|
prior_loss_weight: float,
|
||||||
debug_dataset: bool,
|
debug_dataset: bool,
|
||||||
|
cache_meta: bool,
|
||||||
|
use_cached_meta: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||||
|
|
||||||
@@ -1484,26 +1486,43 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
logger.warning(f"not directory: {subset.image_dir}")
|
logger.warning(f"not directory: {subset.image_dir}")
|
||||||
return [], []
|
return [], []
|
||||||
|
|
||||||
img_paths = glob_images(subset.image_dir, "*")
|
sizes = None
|
||||||
|
if use_cached_meta:
|
||||||
|
logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
|
||||||
|
# [img_path, caption, resolution]
|
||||||
|
with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
|
||||||
|
metas = f.readlines()
|
||||||
|
metas = [x.strip().split("<|##|>") for x in metas]
|
||||||
|
sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
|
||||||
|
|
||||||
|
if use_cached_meta:
|
||||||
|
img_paths = [x[0] for x in metas]
|
||||||
|
else:
|
||||||
|
img_paths = glob_images(subset.image_dir, "*")
|
||||||
|
sizes = [None]*len(img_paths)
|
||||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||||
|
|
||||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
if use_cached_meta:
|
||||||
captions = []
|
captions = [x[1] for x in metas]
|
||||||
missing_captions = []
|
missing_captions = [x[0] for x in metas if x[1] == ""]
|
||||||
for img_path in img_paths:
|
else:
|
||||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||||
if cap_for_img is None and subset.class_tokens is None:
|
captions = []
|
||||||
logger.warning(
|
missing_captions = []
|
||||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
for img_path in img_paths:
|
||||||
)
|
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||||
captions.append("")
|
if cap_for_img is None and subset.class_tokens is None:
|
||||||
missing_captions.append(img_path)
|
logger.warning(
|
||||||
else:
|
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||||
if cap_for_img is None:
|
)
|
||||||
captions.append(subset.class_tokens)
|
captions.append("")
|
||||||
missing_captions.append(img_path)
|
missing_captions.append(img_path)
|
||||||
else:
|
else:
|
||||||
captions.append(cap_for_img)
|
if cap_for_img is None:
|
||||||
|
captions.append(subset.class_tokens)
|
||||||
|
missing_captions.append(img_path)
|
||||||
|
else:
|
||||||
|
captions.append(cap_for_img)
|
||||||
|
|
||||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||||
|
|
||||||
@@ -1520,7 +1539,21 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
|
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
|
||||||
break
|
break
|
||||||
logger.warning(missing_caption)
|
logger.warning(missing_caption)
|
||||||
return img_paths, captions
|
|
||||||
|
if cache_meta:
|
||||||
|
logger.info(f"cache metadata for {subset.image_dir}")
|
||||||
|
if sizes is None or sizes[0] is None:
|
||||||
|
sizes = [self.get_image_size(img_path) for img_path in img_paths]
|
||||||
|
# [img_path, caption, resolution]
|
||||||
|
data = [
|
||||||
|
(img_path, caption, " ".join(str(x) for x in size))
|
||||||
|
for img_path, caption, size in zip(img_paths, captions, sizes)
|
||||||
|
]
|
||||||
|
with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
|
||||||
|
f.write("\n".join(["<|##|>".join(x) for x in data]))
|
||||||
|
logger.info(f"cache metadata done for {subset.image_dir}")
|
||||||
|
|
||||||
|
return img_paths, captions, sizes
|
||||||
|
|
||||||
logger.info("prepare images.")
|
logger.info("prepare images.")
|
||||||
num_train_images = 0
|
num_train_images = 0
|
||||||
@@ -1539,7 +1572,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
img_paths, captions = load_dreambooth_dir(subset)
|
img_paths, captions, sizes = load_dreambooth_dir(subset)
|
||||||
if len(img_paths) < 1:
|
if len(img_paths) < 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
|
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
|
||||||
@@ -1551,8 +1584,10 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
else:
|
else:
|
||||||
num_train_images += subset.num_repeats * len(img_paths)
|
num_train_images += subset.num_repeats * len(img_paths)
|
||||||
|
|
||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||||
|
if size is not None:
|
||||||
|
info.image_size = size
|
||||||
if subset.is_reg:
|
if subset.is_reg:
|
||||||
reg_infos.append((info, subset))
|
reg_infos.append((info, subset))
|
||||||
else:
|
else:
|
||||||
@@ -3355,6 +3390,12 @@ def add_dataset_arguments(
|
|||||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||||
):
|
):
|
||||||
# dataset common
|
# dataset common
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_meta", action="store_true"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_cached_meta", action="store_true"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
|
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ easygui==0.98.3
|
|||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
huggingface-hub==0.20.1
|
huggingface-hub==0.20.1
|
||||||
|
# for Image utils
|
||||||
|
imagesize==1.4.1
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
# requests==2.28.2
|
# requests==2.28.2
|
||||||
# timm==0.6.12
|
# timm==0.6.12
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import sys
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ from library import model_util
|
|||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import (
|
from library.train_util import (
|
||||||
DreamBoothDataset,
|
DreamBoothDataset, DatasetGroup
|
||||||
)
|
)
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
|
|||||||
Reference in New Issue
Block a user