From 09f575fd4dea389af9fd61508b1854e9ee7a594c Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 2 Mar 2023 21:17:25 +0900
Subject: [PATCH] merge image_dir for metadata editor

---
 train_network.py | 140 ++++++++++++++++++++++++++++------------------
 1 file changed, 85 insertions(+), 55 deletions(-)

diff --git a/train_network.py b/train_network.py
index 5cd08d49..28c2c769 100644
--- a/train_network.py
+++ b/train_network.py
@@ -15,12 +15,12 @@ from diffusers import DDPMScheduler
 
 import library.train_util as train_util
 from library.train_util import (
-  DreamBoothDataset,
+    DreamBoothDataset,
 )
 import library.config_util as config_util
 from library.config_util import (
-  ConfigSanitizer,
-  BlueprintGenerator,
+    ConfigSanitizer,
+    BlueprintGenerator,
 )
 
 
@@ -68,24 +68,25 @@ def train(args):
     user_config = config_util.load_user_config(args.config_file)
     ignored = ["train_data_dir", "reg_data_dir", "in_json"]
     if any(getattr(args, attr) is not None for attr in ignored):
-      print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+      print(
+          "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
   else:
     if use_dreambooth_method:
       print("Use DreamBooth method.")
       user_config = {
-        "datasets": [{
-          "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
-        }]
+          "datasets": [{
+              "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
+          }]
       }
     else:
       print("Train with captions.")
       user_config = {
-        "datasets": [{
-          "subsets": [{
-            "image_dir": args.train_data_dir,
-            "metadata_file": args.in_json,
+          "datasets": [{
+              "subsets": [{
+                  "image_dir": args.train_data_dir,
+                  "metadata_file": args.in_json,
+              }]
           }]
-        }]
       }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
@@ -99,7 +100,8 @@ def train(args):
     return
 
   if cache_latents:
-    assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+    assert train_dataset_group.is_latent_cacheable(
+    ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
   # acceleratorを準備する
   print("prepare accelerator")
@@ -255,10 +257,11 @@ def train(args):
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
-  #print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+  # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
   print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
   print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
+  # TODO refactor metadata creation and move to util
   metadata = {
       "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
       "ss_training_started_at": training_started_at,  # unix timestamp
@@ -304,48 +307,73 @@ def train(args):
     # or should also pack nested collections as json?
     datasets_metadata = []
     tag_frequency = {}  # merge tag frequency for metadata editor
+    dataset_dirs_info = {}  # merge subset dirs for metadata editor
 
     for dataset in train_dataset_group.datasets:
       is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
       dataset_metadata = {
-        "is_dreambooth": is_dreambooth_dataset,
-        "batch_size_per_device": dataset.batch_size,
-        "num_train_images": dataset.num_train_images,  # includes repeating
-        "num_reg_images": dataset.num_reg_images,
-        "resolution": (dataset.width, dataset.height),
-        "enable_bucket": bool(dataset.enable_bucket),
-        "min_bucket_reso": dataset.min_bucket_reso,
-        "max_bucket_reso": dataset.max_bucket_reso,
-        "tag_frequency": dataset.tag_frequency,
-        "bucket_info": dataset.bucket_info,
+          "is_dreambooth": is_dreambooth_dataset,
+          "batch_size_per_device": dataset.batch_size,
+          "num_train_images": dataset.num_train_images,  # includes repeating
+          "num_reg_images": dataset.num_reg_images,
+          "resolution": (dataset.width, dataset.height),
+          "enable_bucket": bool(dataset.enable_bucket),
+          "min_bucket_reso": dataset.min_bucket_reso,
+          "max_bucket_reso": dataset.max_bucket_reso,
+          "tag_frequency": dataset.tag_frequency,
+          "bucket_info": dataset.bucket_info,
       }
 
       subsets_metadata = []
       for subset in dataset.subsets:
         subset_metadata = {
-          "img_count": subset.img_count,
-          "num_repeats": subset.num_repeats,
-          "color_aug": bool(subset.color_aug),
-          "flip_aug": bool(subset.flip_aug),
-          "random_crop": bool(subset.random_crop),
-          "shuffle_caption": bool(subset.shuffle_caption),
-          "keep_tokens": subset.keep_tokens,
+            "img_count": subset.img_count,
+            "num_repeats": subset.num_repeats,
+            "color_aug": bool(subset.color_aug),
+            "flip_aug": bool(subset.flip_aug),
+            "random_crop": bool(subset.random_crop),
+            "shuffle_caption": bool(subset.shuffle_caption),
+            "keep_tokens": subset.keep_tokens,
         }
+
+        image_dir_or_metadata_file = None
         if subset.image_dir:
-          subset_metadata["image_dir"] = os.path.basename(subset.image_dir)
+          image_dir = os.path.basename(subset.image_dir)
+          subset_metadata["image_dir"] = image_dir
+          image_dir_or_metadata_file = image_dir
 
         if is_dreambooth_dataset:
           subset_metadata["class_tokens"] = subset.class_tokens
           subset_metadata["is_reg"] = subset.is_reg
+          if subset.is_reg:
+            image_dir_or_metadata_file = None  # not merging reg dataset
         else:
-          subset_metadata["metadata_file"] = os.path.basename(subset.metadata_file)
+          metadata_file = os.path.basename(subset.metadata_file)
+          subset_metadata["metadata_file"] = metadata_file
+          image_dir_or_metadata_file = metadata_file  # may overwrite
 
         subsets_metadata.append(subset_metadata)
 
+        # merge dataset dir: not reg subset only
+        # TODO update additional-network extension to show detailed dataset config from metadata
+        if image_dir_or_metadata_file is not None:
+          # datasets may have a certain dir multiple times
+          v = image_dir_or_metadata_file
+          i = 2
+          while v in dataset_dirs_info:
+            v = image_dir_or_metadata_file + f" ({i})"
+            i += 1
+          image_dir_or_metadata_file = v
+
+          dataset_dirs_info[image_dir_or_metadata_file] = {
+              "n_repeats": subset.num_repeats,
+              "img_count": subset.img_count
+          }
+
       dataset_metadata["subsets"] = subsets_metadata
       datasets_metadata.append(dataset_metadata)
 
-
       # merge tag frequency:
       for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
         # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
         # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
@@ -356,9 +384,11 @@ def train(args):
     metadata["ss_datasets"] = json.dumps(datasets_metadata)
     metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
+    metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
 
   else:
     # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
-    assert len(train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
+    assert len(
+        train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
 
     dataset = train_dataset_group.datasets[0]
 
@@ -368,33 +398,33 @@ def train(args):
       for subset in dataset.subsets:
        info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
        info[os.path.basename(subset.image_dir)] = {
-          "n_repeats": subset.num_repeats,
-          "img_count": subset.img_count
+            "n_repeats": subset.num_repeats,
+            "img_count": subset.img_count
        }
     else:
       for subset in dataset.subsets:
         dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
-          "n_repeats": subset.num_repeats,
-          "img_count": subset.img_count
+            "n_repeats": subset.num_repeats,
+            "img_count": subset.img_count
         }
 
     metadata |= {
-      "ss_batch_size_per_device": args.train_batch_size,
-      "ss_total_batch_size": total_batch_size,
-      "ss_resolution": args.resolution,
-      "ss_color_aug": bool(args.color_aug),
-      "ss_flip_aug": bool(args.flip_aug),
-      "ss_random_crop": bool(args.random_crop),
-      "ss_shuffle_caption": bool(args.shuffle_caption),
-      "ss_enable_bucket": bool(dataset.enable_bucket),
-      "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
-      "ss_min_bucket_reso": dataset.min_bucket_reso,
-      "ss_max_bucket_reso": dataset.max_bucket_reso,
-      "ss_keep_tokens": args.keep_tokens,
-      "ss_dataset_dirs": json.dumps(dataset_dirs_info),
-      "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
-      "ss_tag_frequency": json.dumps(dataset.tag_frequency),
-      "ss_bucket_info": json.dumps(dataset.bucket_info),
+        "ss_batch_size_per_device": args.train_batch_size,
+        "ss_total_batch_size": total_batch_size,
+        "ss_resolution": args.resolution,
+        "ss_color_aug": bool(args.color_aug),
+        "ss_flip_aug": bool(args.flip_aug),
+        "ss_random_crop": bool(args.random_crop),
+        "ss_shuffle_caption": bool(args.shuffle_caption),
+        "ss_enable_bucket": bool(dataset.enable_bucket),
+        "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
+        "ss_min_bucket_reso": dataset.min_bucket_reso,
+        "ss_max_bucket_reso": dataset.max_bucket_reso,
+        "ss_keep_tokens": args.keep_tokens,
+        "ss_dataset_dirs": json.dumps(dataset_dirs_info),
+        "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
+        "ss_tag_frequency": json.dumps(dataset.tag_frequency),
+        "ss_bucket_info": json.dumps(dataset.bucket_info),
     }
 
   # uncomment if another network is added
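
The heart of this patch is the dataset_dirs_info merge: each non-regularization subset contributes the basename of its image_dir (DreamBooth datasets) or metadata_file (caption datasets) as a key, and duplicate basenames are disambiguated with " (2)", " (3)", ... suffixes before the dict is serialized into the ss_dataset_dirs metadata entry. Below is a minimal standalone sketch of that dedup logic; the merge_dataset_dirs helper and the (name, num_repeats, img_count) tuples are hypothetical stand-ins for the patch's subset objects, not code from the patch itself.

import json

def merge_dataset_dirs(subsets):
  """Merge (name, num_repeats, img_count) tuples into a dataset_dirs_info dict.

  name is the basename of image_dir or metadata_file, or None for a
  regularization subset, which is skipped (the patch does not merge reg dirs).
  """
  dataset_dirs_info = {}
  for name, num_repeats, img_count in subsets:
    if name is None:  # reg subsets are not merged
      continue
    v = name
    i = 2
    while v in dataset_dirs_info:  # the same dir may appear in several datasets
      v = name + f" ({i})"
      i += 1
    dataset_dirs_info[v] = {"n_repeats": num_repeats, "img_count": img_count}
  return dataset_dirs_info

# hypothetical example: "1_girl" is used by two datasets; the reg subset is excluded
print(json.dumps(merge_dataset_dirs([("1_girl", 2, 20), ("1_girl", 1, 30), (None, 1, 100)])))
# {"1_girl": {"n_repeats": 2, "img_count": 20}, "1_girl (2)": {"n_repeats": 1, "img_count": 30}}

Keying the merged dict on suffixed names rather than a list keeps ss_dataset_dirs backward compatible with the single-dataset format that the else branch still writes, so the metadata editor can read both shapes the same way.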