Files
Kohya-ss-sd-scripts/cache_dataset_meta.py
Kohaku-Blueleaf ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00

104 lines
3.5 KiB
Python

import argparse
import random
from accelerate.utils import set_seed
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def make_dataset(args):
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(
ConfigSanitizer(True, True, False, True)
)
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=None)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
blueprint.dataset_group
)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer=None)
return train_dataset_group
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args, unknown = parser.parse_known_args()
args = train_util.read_config_from_file(args, parser)
if args.max_token_length is None:
args.max_token_length = 75
args.cache_meta = True
dataset_group = make_dataset(args)