add recursive structure merge dd tags and convert to pathlib

This commit is contained in:
breakcore2
2023-01-26 01:01:38 -08:00
parent 00f1296537
commit 2ce9ad235c

View File

@@ -2,25 +2,33 @@
# (c) 2022 Kohya S. @kohya_ss # (c) 2022 Kohya S. @kohya_ss
import argparse import argparse
import glob
import os
import json import json
from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
def main(args): def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) image_paths = None
train_data_dir_path = Path(args.train_data_dir)
if args.recursive_data_dir:
image_paths = list(train_data_dir_path.rglob('*.jpg')) + \
list(train_data_dir_path.rglob('*.png')) + \
list(train_data_dir_path.rglob('*.webp'))
else:
image_paths = list(train_data_dir_path.glob('*.jpg')) + \
list(train_data_dir_path.glob('*.png')) + \
list(train_data_dir_path.glob('*.webp'))
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json): if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json args.in_json = args.out_json
if args.in_json is not None: if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}") print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f: metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
metadata = json.load(f)
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else: else:
print("new metadata will be created / 新しいメタデータファイルが作成されます") print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,22 +36,21 @@ def main(args):
print("merge tags to metadata json.") print("merge tags to metadata json.")
for image_path in tqdm(image_paths): for image_path in tqdm(image_paths):
tags_path = os.path.splitext(image_path)[0] + '.txt' tags_path = image_path.with_suffix('.txt')
with open(tags_path, "rt", encoding='utf-8') as f: tags = tags_path.read_text(encoding='utf-8').strip()
tags = f.readlines()[0].strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] image_key = image_path if args.full_path else image_path.stem
if image_key not in metadata: if str(image_key) not in metadata:
metadata[image_key] = {} metadata[str(image_key)] = {}
metadata[image_key]['tags'] = tags metadata[str(image_key)]['tags'] = tags
if args.debug: if args.debug:
print(image_key, tags) print(image_key, tags)
# metadataを書き出して終わり # metadataを書き出して終わり
print(f"writing metadata: {args.out_json}") print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f: Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
json.dump(metadata, f, indent=2)
print("done!") print("done!")
@@ -54,6 +61,7 @@ if __name__ == '__main__':
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む") parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true", parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive_data_dir", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags") parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
args = parser.parse_args() args = parser.parse_args()