From 227a62e4c4d3c3c5269a244328609ce2da96ebda Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 30 Jun 2023 07:40:22 +0900 Subject: [PATCH] fix to work with dreambooth ds without toml --- sdxl_train.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 66b3c76d..35934c38 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -41,6 +41,7 @@ def train(args): ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -61,18 +62,31 @@ def train(args): ) ) else: - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)