mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
refactor metadata caching for DreamBooth dataset
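This change replaces the earlier text-based metadata cache (a `dataset.txt` per image directory, with fields joined by `<|##|>`) and its `--cache_meta` / `--use_cached_meta` options with a single per-subset option, `cache_info`, backed by a JSON file named `metadata_cache.json` in each subset's `image_dir`. A sketch of the new file's layout, with a hypothetical path and caption (the structure follows the comment in the diff below):

```python
# One entry per image in <image_dir>/metadata_cache.json:
metas = {
    "/path/to/train_images/0001.png": {
        "caption": "a photo of sks girl",  # empty string if no caption was found
        "resolution": [512, 768],          # [width, height]
    },
}
```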
library/config_util.py:

@@ -41,12 +41,17 @@ from .train_util import (
     DatasetGroup,
 )
 from .utils import setup_logging

 setup_logging()
 import logging

 logger = logging.getLogger(__name__)


 def add_config_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+    parser.add_argument(
+        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
+    )


 # TODO: inherit Params class in Subset, Dataset
@@ -80,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams):
     is_reg: bool = False
     class_tokens: Optional[str] = None
     caption_extension: str = ".caption"
+    cache_info: bool = False


 @dataclass
@@ -91,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams):
 class ControlNetSubsetParams(BaseSubsetParams):
     conditioning_data_dir: str = None
     caption_extension: str = ".caption"
+    cache_info: bool = False


 @dataclass
@@ -111,8 +118,6 @@ class DreamBoothDatasetParams(BaseDatasetParams):
     bucket_reso_steps: int = 64
     bucket_no_upscale: bool = False
     prior_loss_weight: float = 1.0
-    cache_meta: bool = False
-    use_cached_meta: bool = False


 @dataclass
@@ -202,6 +207,7 @@ class ConfigSanitizer:
     DB_SUBSET_ASCENDABLE_SCHEMA = {
         "caption_extension": str,
         "class_tokens": str,
+        "cache_info": bool,
     }
     DB_SUBSET_DISTINCT_SCHEMA = {
         Required("image_dir"): str,
@@ -214,6 +220,7 @@ class ConfigSanitizer:
     }
     CN_SUBSET_ASCENDABLE_SCHEMA = {
         "caption_extension": str,
+        "cache_info": bool,
     }
     CN_SUBSET_DISTINCT_SCHEMA = {
         Required("image_dir"): str,
@@ -230,8 +237,6 @@ class ConfigSanitizer:
         "min_bucket_reso": int,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
         "network_multiplier": float,
-        "cache_meta": bool,
-        "use_cached_meta": bool,
     }

     # options handled by argparse but not handled by user config
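Since `cache_info` is added to the ascendable subset schemas while `cache_meta` / `use_cached_meta` leave the dataset schema, the option can now be set per subset in the user config, or at a higher level and inherited. A hypothetical minimal config enabling it (keys match the sanitizer schemas above; paths and values are illustrative):

```python
import toml

# Illustrative dataset config exercising the new subset-level key.
user_config = toml.loads("""
[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = "/path/to/train_images"
  class_tokens = "sks girl"
  cache_info = true
""")
```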
@@ -366,7 +371,9 @@ class ConfigSanitizer:
             return self.argparse_config_validator(argparse_namespace)
         except MultipleInvalid:
             # XXX: this should be a bug
-            logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+            logger.error(
+                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+            )
             raise

     # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -551,11 +558,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
             " ",
         )

-        logger.info(f'{info}')
+        logger.info(f"{info}")

     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
-    seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
+    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
     for i, dataset in enumerate(datasets):
         logger.info(f"[Dataset {i}]")
         dataset.make_buckets()
@@ -642,13 +649,17 @@ def load_user_config(file: str) -> dict:
             with open(file, "r") as f:
                 config = json.load(f)
         except Exception:
-            logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     elif file.name.lower().endswith(".toml"):
         try:
             config = toml.load(file)
         except Exception:
-            logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     else:
         raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -675,13 +686,13 @@ if __name__ == "__main__":
     train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

     logger.info("[argparse_namespace]")
-    logger.info(f'{vars(argparse_namespace)}')
+    logger.info(f"{vars(argparse_namespace)}")

     user_config = load_user_config(config_args.dataset_config)

     logger.info("")
     logger.info("[user_config]")
-    logger.info(f'{user_config}')
+    logger.info(f"{user_config}")

     sanitizer = ConfigSanitizer(
         config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -690,10 +701,10 @@ if __name__ == "__main__":

     logger.info("")
     logger.info("[sanitized_user_config]")
-    logger.info(f'{sanitized_user_config}')
+    logger.info(f"{sanitized_user_config}")

     blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

     logger.info("")
     logger.info("[blueprint]")
-    logger.info(f'{blueprint}')
+    logger.info(f"{blueprint}")
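For reference, a minimal sketch of the flow this `__main__` block exercises, assuming a config file at a hypothetical path and that `ConfigSanitizer` exposes `sanitize_user_config`, as the `[sanitized_user_config]` log above suggests:

```python
from pathlib import Path
from library.config_util import ConfigSanitizer, load_user_config

user_config = load_user_config(Path("dataset_config.toml"))  # hypothetical path
# flags: support_dreambooth, support_finetuning, support_controlnet, support_dropout
sanitizer = ConfigSanitizer(True, True, False, True)
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
```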
library/train_util.py:

@@ -410,6 +410,7 @@ class DreamBoothSubset(BaseSubset):
         is_reg: bool,
         class_tokens: Optional[str],
         caption_extension: str,
+        cache_info: bool,
         num_repeats,
         shuffle_caption,
         caption_separator: str,
@@ -458,6 +459,7 @@ class DreamBoothSubset(BaseSubset):
         self.caption_extension = caption_extension
         if self.caption_extension and not self.caption_extension.startswith("."):
             self.caption_extension = "." + self.caption_extension
+        self.cache_info = cache_info

     def __eq__(self, other) -> bool:
         if not isinstance(other, DreamBoothSubset):
@@ -527,6 +529,7 @@ class ControlNetSubset(BaseSubset):
         image_dir: str,
         conditioning_data_dir: str,
         caption_extension: str,
+        cache_info: bool,
         num_repeats,
         shuffle_caption,
         caption_separator,
@@ -574,6 +577,7 @@ class ControlNetSubset(BaseSubset):
         self.caption_extension = caption_extension
         if self.caption_extension and not self.caption_extension.startswith("."):
             self.caption_extension = "." + self.caption_extension
+        self.cache_info = cache_info

     def __eq__(self, other) -> bool:
         if not isinstance(other, ControlNetSubset):
@@ -1410,6 +1414,8 @@ class BaseDataset(torch.utils.data.Dataset):


 class DreamBoothDataset(BaseDataset):
+    IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
+
     def __init__(
         self,
         subsets: Sequence[DreamBoothSubset],
@@ -1425,8 +1431,6 @@ class DreamBoothDataset(BaseDataset):
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
-        cache_meta: bool,
-        use_cached_meta: bool,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1486,25 +1490,36 @@ class DreamBoothDataset(BaseDataset):
                 logger.warning(f"not directory: {subset.image_dir}")
                 return [], []

-            sizes = None
-            if use_cached_meta:
-                logger.info(f"using cached metadata: {subset.image_dir}/dataset.txt")
-                # [img_path, caption, resolution]
-                with open(f"{subset.image_dir}/dataset.txt", "r", encoding="utf-8") as f:
-                    metas = f.readlines()
-                metas = [x.strip().split("<|##|>") for x in metas]
-                sizes = [tuple(int(res) for res in x[2].split(" ")) for x in metas]
-
-            if use_cached_meta:
-                img_paths = [x[0] for x in metas]
+            info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
+            use_cached_info_for_subset = subset.cache_info
+            if use_cached_info_for_subset:
+                logger.info(
+                    f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
+                )
+                if not os.path.isfile(info_cache_file):
+                    logger.warning(
+                        "image info file not found. You can ignore this warning if this is the first time you use this subset"
+                        + f" / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {info_cache_file}"
+                    )
+                    use_cached_info_for_subset = False

+            if use_cached_info_for_subset:
+                # json: {`img_path`: {"caption": "caption...", "resolution": [width, height]}, ...}
+                with open(info_cache_file, "r", encoding="utf-8") as f:
+                    metas = json.load(f)
+                img_paths = list(metas.keys())
+                sizes = [meta["resolution"] for meta in metas.values()]
+
+                # we may need to check image size and existence of image files, but it takes time, so user should check it before training
             else:
                 img_paths = glob_images(subset.image_dir, "*")
-                sizes = [None]*len(img_paths)
+                sizes = [None] * len(img_paths)

             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

-            if use_cached_meta:
-                captions = [x[1] for x in metas]
-                missing_captions = [x[0] for x in metas if x[1] == ""]
+            if use_cached_info_for_subset:
+                captions = [meta["caption"] for meta in metas.values()]
+                missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
             else:
                 # read the caption file for each image and use it if it exists / 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
                 captions = []
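Read in isolation, the new cache path boils down to the following standalone sketch (the directory path is hypothetical; the dict layout is the one documented in the hunk above):

```python
import json
import os

image_dir = "/path/to/train_images"  # hypothetical subset.image_dir
info_cache_file = os.path.join(image_dir, "metadata_cache.json")

with open(info_cache_file, "r", encoding="utf-8") as f:
    metas = json.load(f)  # {img_path: {"caption": ..., "resolution": [w, h]}, ...}

img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]
captions = [meta["caption"] for meta in metas.values()]
```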
@@ -1540,19 +1555,17 @@ class DreamBoothDataset(BaseDataset):
                     break
                 logger.warning(missing_caption)

-            if cache_meta:
-                logger.info(f"cache metadata for {subset.image_dir}")
-                if sizes is None or sizes[0] is None:
-                    sizes = [self.get_image_size(img_path) for img_path in img_paths]
-                # [img_path, caption, resolution]
-                data = [
-                    (img_path, caption, " ".join(str(x) for x in size))
-                    for img_path, caption, size in zip(img_paths, captions, sizes)
-                ]
-                with open(f"{subset.image_dir}/dataset.txt", "w", encoding="utf-8") as f:
-                    f.write("\n".join(["<|##|>".join(x) for x in data]))
-                logger.info(f"cache metadata done for {subset.image_dir}")
+            if not use_cached_info_for_subset and subset.cache_info:
+                logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
+                sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
+                metas = {}
+                for img_path, caption, size in zip(img_paths, captions, sizes):
+                    metas[img_path] = {"caption": caption, "resolution": list(size)}
+                with open(info_cache_file, "w", encoding="utf-8") as f:
+                    json.dump(metas, f, ensure_ascii=False, indent=2)
+                logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")

+            # if sizes are not set, image size will be read in make_buckets
             return img_paths, captions, sizes

         logger.info("prepare images.")
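The cache is written on the first run with `cache_info` enabled and read on subsequent runs. If you want to pre-build it outside training, a hedged standalone sketch (the helper name and file filtering are illustrative, not part of the library):

```python
import json
import os

from PIL import Image

def build_info_cache(image_dir: str, caption_ext: str = ".caption") -> None:
    """Write <image_dir>/metadata_cache.json in the format the dataset expects."""
    metas = {}
    for name in sorted(os.listdir(image_dir)):
        if not name.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
            continue
        img_path = os.path.join(image_dir, name)
        # use the sidecar caption file if present, otherwise an empty caption
        caption = ""
        caption_path = os.path.splitext(img_path)[0] + caption_ext
        if os.path.isfile(caption_path):
            with open(caption_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
        with Image.open(img_path) as img:
            width, height = img.size
        metas[img_path] = {"caption": caption, "resolution": [width, height]}
    with open(os.path.join(image_dir, "metadata_cache.json"), "w", encoding="utf-8") as f:
        json.dump(metas, f, ensure_ascii=False, indent=2)
```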
@@ -1873,7 +1886,8 @@ class ControlNetDataset(BaseDataset):
                 subset.image_dir,
                 False,
                 None,
                 subset.caption_extension,
+                subset.cache_info,
                 subset.num_repeats,
                 subset.shuffle_caption,
                 subset.caption_separator,
@@ -3390,15 +3404,15 @@ def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
 ):
     # dataset common
-    parser.add_argument(
-        "--cache_meta", action="store_true"
-    )
-    parser.add_argument(
-        "--use_cached_meta", action="store_true"
-    )
     parser.add_argument(
         "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
     )
+    parser.add_argument(
+        "--cache_info",
+        action="store_true",
+        help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
+        + " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
+    )
     parser.add_argument(
         "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
     )
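With the argparse plumbing above, enabling the cache from the command line is a single flag. A small sketch of how it lands in the parsed namespace (the support flags passed here are illustrative):

```python
import argparse

from library import train_util

parser = argparse.ArgumentParser()
train_util.add_dataset_arguments(parser, True, True, True)
args = parser.parse_args(["--train_data_dir", "/path/to/train_images", "--cache_info"])
assert args.cache_info is True
```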