diff --git a/sdxl_train.py b/sdxl_train.py index 66b3c76d..35934c38 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -41,6 +41,7 @@ def train(args): ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -61,18 +62,31 @@ def train(args): ) ) else: - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)