feat: added 7 new functionalities including recursive image search

This commit is contained in:
Linaqruf
2023-04-07 16:51:43 +07:00
committed by GitHub
parent b5c60d7d62
commit 07aa000750

View File

@@ -10,6 +10,7 @@ import numpy as np
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import torch import torch
import pathlib
import library.train_util as train_util import library.train_util as train_util
@@ -23,7 +24,6 @@ SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1] CSV_FILE = FILES[-1]
def preprocess_image(image): def preprocess_image(image):
image = np.array(image) image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR image = image[:, :, ::-1] # RGB->BGR
@@ -42,7 +42,6 @@ def preprocess_image(image):
image = image.astype(np.float32) image = image.astype(np.float32)
return image return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset): class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths): def __init__(self, image_paths):
self.images = image_paths self.images = image_paths
@@ -51,7 +50,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
return len(self.images) return len(self.images)
def __getitem__(self, idx): def __getitem__(self, idx):
img_path = self.images[idx] img_path = str(self.images[idx])
try: try:
image = Image.open(img_path).convert("RGB") image = Image.open(img_path).convert("RGB")
@@ -63,7 +62,6 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
return (tensor, img_path) return (tensor, img_path)
def collate_fn_remove_corrupted(batch): def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the """Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs. dataloader. It expects that the dataloader returns 'None' when that occurs.
@@ -73,7 +71,6 @@ def collate_fn_remove_corrupted(batch):
batch = list(filter(lambda x: x is not None, batch)) batch = list(filter(lambda x: x is not None, batch))
return batch return batch
def main(args): def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時 # depreacatedの警告が出るけどなくなったらその時
@@ -89,14 +86,11 @@ def main(args):
print("using existing wd14 tagger model") print("using existing wd14 tagger model")
# 画像を読み込む # 画像を読み込む
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print("loading model and labels")
model = load_model(args.model_dir) model = load_model(args.model_dir)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ # 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f) reader = csv.reader(f)
l = [row for row in reader] l = [row for row in reader]
@@ -104,9 +98,19 @@ def main(args):
rows = l[1:] rows = l[1:]
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ general_tags = [row[1] for row in rows[1:] if row[2] == '0']
character_tags = [row[1] for row in rows[1:] if row[2] == '4']
# 画像を読み込む
train_data_dir = pathlib.Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive)
print(f"found {len(image_paths)} images.")
tag_freq = {}
undesired_tags = set(args.undesired_tags.split(','))
# 推論する
def run_batch(path_imgs): def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs]) imgs = np.array([im for _, im in path_imgs])
@@ -122,18 +126,36 @@ def main(args):
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する # それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold # Everything else is tags: pick any where prediction confidence > threshold
tag_text = "" combined_tags = []
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで general_tag_text = ""
if p >= args.thresh and i < len(tags): character_tag_text = ""
tag_text += ", " + tags[i] for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
combined_tags.append(tag_name)
if len(tag_text) > 0: if len(general_tag_text) > 0:
tag_text = tag_text[2:] # 最初の ", " を消す general_tag_text = general_tag_text[2:]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
tag_text = ', '.join(combined_tags)
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(tag_text + '\n') f.write(tag_text + '\n')
if args.debug: if args.debug:
print(image_path, tag_text) print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
@@ -164,38 +186,45 @@ def main(args):
b_imgs.append((image_path, image)) b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size: if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs) run_batch(b_imgs)
b_imgs.clear() b_imgs.clear()
if len(b_imgs) > 0: if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs) run_batch(b_imgs)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
print("\nTag frequencies:")
for tag, freq in sorted_tags:
print(f"{tag}: {freq}")
print("done!") print("done!")
def setup_parser() -> argparse.ArgumentParser: if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
parser.add_argument("--force_download", action='store_true', parser.add_argument("--force_download", action='store_true',
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None, parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化") help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None, parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--general_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for general category")
parser.add_argument("--character_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for character category")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output")
return parser parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images')
if __name__ == '__main__':
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()