diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py
index 896c99ae..16267cd8 100644
--- a/finetune/merge_dd_tags_to_metadata.py
+++ b/finetune/merge_dd_tags_to_metadata.py
@@ -2,25 +2,34 @@
 # (c) 2022 Kohya S. @kohya_ss
 
 import argparse
-import glob
-import os
 import json
+from pathlib import Path
 
 from tqdm import tqdm
 
 
 def main(args):
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = None
+  train_data_dir_path = Path(args.train_data_dir)
+  if args.recursive:
+    image_paths = list(train_data_dir_path.rglob('*.jpg')) + \
+        list(train_data_dir_path.rglob('*.jpeg')) + \
+        list(train_data_dir_path.rglob('*.png')) + \
+        list(train_data_dir_path.rglob('*.webp'))
+  else:
+    image_paths = list(train_data_dir_path.glob('*.jpg')) + \
+        list(train_data_dir_path.glob('*.jpeg')) + \
+        list(train_data_dir_path.glob('*.png')) + \
+        list(train_data_dir_path.glob('*.webp'))
+
   print(f"found {len(image_paths)} images.")
 
-  if args.in_json is None and os.path.isfile(args.out_json):
+  if args.in_json is None and Path(args.out_json).is_file():
     args.in_json = args.out_json
 
   if args.in_json is not None:
     print(f"loading existing metadata: {args.in_json}")
-    with open(args.in_json, "rt", encoding='utf-8') as f:
-      metadata = json.load(f)
+    metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
     print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
   else:
     print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,22 +37,21 @@ def main(args):
   print("merge tags to metadata json.")
   for image_path in tqdm(image_paths):
-    tags_path = os.path.splitext(image_path)[0] + '.txt'
-    with open(tags_path, "rt", encoding='utf-8') as f:
-      tags = f.readlines()[0].strip()
+    tags_path = image_path.with_suffix('.txt')
+    tags = tags_path.read_text(encoding='utf-8').strip()
 
-    image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
-    if image_key not in metadata:
-      metadata[image_key] = {}
+    image_key = image_path if args.full_path else image_path.stem
+    if str(image_key) not in metadata:
+      metadata[str(image_key)] = {}
 
-    metadata[image_key]['tags'] = tags
+    metadata[str(image_key)]['tags'] = tags
 
     if args.debug:
       print(image_key, tags)
 
   # metadataを書き出して終わり
   print(f"writing metadata: {args.out_json}")
-  with open(args.out_json, "wt", encoding='utf-8') as f:
-    json.dump(metadata, f, indent=2)
+  Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
+
   print("done!")
@@ -54,6 +62,7 @@ if __name__ == '__main__':
   parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
   parser.add_argument("--full_path", action="store_true",
                       help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
+  parser.add_argument("--recursive", action="store_true",
+                      help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
   parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
 
   args = parser.parse_args()
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index 60a9a890..d7166114 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -36,8 +36,11 @@ def main(args):
                               args.model_dir, SUB_DIR), force_download=True, force_filename=file)
 
   # 画像を読み込む
-  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
-      glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
+      glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
+      glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
+      glob.glob(os.path.join(args.train_data_dir, "*.webp")) + \
+      glob.glob(os.path.join(args.train_data_dir, "*.bmp"))
   print(f"found {len(image_paths)} images.")
 
   print("loading model and labels")
diff --git a/library/train_util.py b/library/train_util.py
index 85b58d7e..0946c31d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -87,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.enable_bucket = False
     self.min_bucket_reso = None
     self.max_bucket_reso = None
+    self.tag_frequency = {}
     self.bucket_info = None
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -545,6 +546,15 @@ class DreamBoothDataset(BaseDataset):
         cap_for_img = read_caption(img_path)
         captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
 
+      frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
+      self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
+      for caption in captions:
+        for tag in caption.split(","):
+          if tag and not tag.isspace():
+            tag = tag.lower()
+            frequency = frequency_for_dir.get(tag, 0)
+            frequency_for_dir[tag] = frequency + 1
+
       return n_repeats, img_paths, captions
 
     print("prepare train images.")
diff --git a/train_network.py b/train_network.py
index 37a10f65..aebc4a40 100644
--- a/train_network.py
+++ b/train_network.py
@@ -335,6 +335,7 @@ def train(args):
       "ss_keep_tokens": args.keep_tokens,
       "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
       "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
+      "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
       "ss_bucket_info": json.dumps(train_dataset.bucket_info),
       "ss_training_comment": args.training_comment  # will not be updated after training
     }