From 07aa000750f83fa9faae083f57457ae383b2e3f7 Mon Sep 17 00:00:00 2001 From: Linaqruf <50163983+Linaqruf@users.noreply.github.com> Date: Fri, 7 Apr 2023 16:51:43 +0700 Subject: [PATCH] feat: added 7 new functionalities including recursive --- finetune/tag_images_by_wd14_tagger.py | 319 ++++++++++++++------------ 1 file changed, 174 insertions(+), 145 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 2286115e..ac00baec 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -10,6 +10,7 @@ import numpy as np from tensorflow.keras.models import load_model from huggingface_hub import hf_hub_download import torch +import pathlib import library.train_util as train_util @@ -23,184 +24,212 @@ SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] - def preprocess_image(image): - image = np.array(image) - image = image[:, :, ::-1] # RGB->BGR + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR - # pad to square - size = max(image.shape[0:2]) - pad_x = size - image.shape[1] - pad_y = size - image.shape[0] - pad_l = pad_x // 2 - pad_t = pad_y // 2 - image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) - - image = image.astype(np.float32) - return image + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + image = image.astype(np.float32) + return image class ImageLoadingPrepDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] - - try: - image = Image.open(img_path).convert("RGB") - image = preprocess_image(image) - tensor = torch.tensor(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (tensor, img_path) + def __getitem__(self, idx): + img_path = str(self.images[idx]) + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + return (tensor, img_path) + def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 - # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - for file in FILES: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( - args.model_dir, SUB_DIR), force_download=True, force_filename=file) - else: - print("using existing wd14 tagger model") + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする + # depreacatedの警告が出るけどなくなったらその時 + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(args.model_dir) or args.force_download: + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + for file in FILES: + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + for file in SUB_DIR_FILES: + hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( + args.model_dir, SUB_DIR), force_download=True, force_filename=file) + else: + print("using existing wd14 tagger model") - # 画像を読み込む - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") + # 画像を読み込む + model = load_model(args.model_dir) - print("loading model and labels") - model = load_model(args.model_dir) + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") + # 依存ライブラリを増やしたくないので自力で読むよ - # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] - assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + l = [row for row in reader] + header = l[0] # tag_id,name,category,count + rows = l[1:] + assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" - tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ + general_tags = [row[1] for row in rows[1:] if row[2] == '0'] + character_tags = [row[1] for row in rows[1:] if row[2] == '4'] - # 推論する - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) + # 画像を読み込む + + train_data_dir = pathlib.Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) + print(f"found {len(image_paths)} images.") - probs = model(imgs, training=False) - probs = probs.numpy() + tag_freq = {} - for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + undesired_tags = set(args.undesired_tags.split(',')) - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold - tag_text = "" - for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで - if p >= args.thresh and i < len(tags): - tag_text += ", " + tags[i] + def run_batch(path_imgs): + imgs = np.array([im for _, im in path_imgs]) - if len(tag_text) > 0: - tag_text = tag_text[2:] # 最初の ", " を消す + probs = model(imgs, training=False) + probs = probs.numpy() - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(tag_text + '\n') - if args.debug: - print(image_path, tag_text) + for (image_path, _), prob in zip(path_imgs, probs): + # 最初の4つはratingなので無視する + # # First 4 labels are actually ratings: pick one with argmax + # ratings_names = label_names[:4] + # rating_index = ratings_names["probs"].argmax() + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] + # それ以降はタグなのでconfidenceがthresholdより高いものを追加する + # Everything else is tags: pick any where prediction confidence > threshold + combined_tags = [] + general_tag_text = "" + character_tag_text = "" + for i, p in enumerate(prob[4:]): + if i < len(general_tags) and p >= args.general_threshold: + tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += ", " + tag_name + combined_tags.append(tag_name) + elif i >= len(general_tags) and p >= args.character_threshold: + tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += ", " + tag_name + combined_tags.append(tag_name) - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue + if len(general_tag_text) > 0: + general_tag_text = general_tag_text[2:] - image, image_path = data - if image is not None: - image = image.detach().numpy() - else: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - image = preprocess_image(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) + if len(character_tag_text) > 0: + character_tag_text = character_tag_text[2:] - if len(b_imgs) >= args.batch_size: + tag_text = ', '.join(combined_tags) + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(tag_text + '\n') + if args.debug: + print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) + + if len(b_imgs) >= args.batch_size: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) + if args.frequency_tags: + sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) + print("\nTag frequencies:") + for tag, freq in sorted_tags: + print(f"{tag}: {freq}") - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, - help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") - parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", - help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") - parser.add_argument("--force_download", action='store_true', - help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser + print("done!") if __name__ == '__main__': - parser = setup_parser() + parser = argparse.ArgumentParser() + parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") + parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") + parser.add_argument("--force_download", action='store_true', + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument("--general_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for general category") + parser.add_argument("--character_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for character category") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags") + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output") + parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images') - args = parser.parse_args() + args = parser.parse_args() - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention - main(args) + main(args)