From 77ad20bc8f39e3595ccb3f733e1a6bc85d4c41d4 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 30 Aug 2025 15:47:47 +0900 Subject: [PATCH] feat: support another tagger model --- docs/wd14_tagger_README-en.md | 6 +- docs/wd14_tagger_README-ja.md | 6 +- finetune/tag_images_by_wd14_tagger.py | 426 +++++++++++++++++++------- 3 files changed, 324 insertions(+), 114 deletions(-) diff --git a/docs/wd14_tagger_README-en.md b/docs/wd14_tagger_README-en.md index 34f44882..48a4e9df 100644 --- a/docs/wd14_tagger_README-en.md +++ b/docs/wd14_tagger_README-en.md @@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github. Using onnx for inference is recommended. Please install onnx with the following command: ```powershell -pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 +pip install onnx onnxruntime-gpu ``` +See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details. + The model weights will be automatically downloaded from Hugging Face. # Usage @@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge # Options +All options can be checked with `python tag_images_by_wd14_tagger.py --help`. + ## General Options - `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately. diff --git a/docs/wd14_tagger_README-ja.md b/docs/wd14_tagger_README-ja.md index 58e9ede9..49d14636 100644 --- a/docs/wd14_tagger_README-ja.md +++ b/docs/wd14_tagger_README-ja.md @@ -5,9 +5,11 @@ onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。 ```powershell -pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 +pip install onnx onnxruntime-gpu ``` +詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。 + モデルの重みはHugging Faceから自動的にダウンロードしてきます。 # 使い方 @@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge # オプション +全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。 + ## 一般オプション - `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。 diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 07a6510e..17a75bd7 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,12 +1,13 @@ import argparse import csv +import json import os from pathlib import Path import cv2 import numpy as np import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, errors from PIL import Image from tqdm import tqdm @@ -29,8 +30,22 @@ SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] +TAG_JSON_FILE = "tag_mapping.json" + + +def preprocess_image(image: Image.Image) -> np.ndarray: + # If image has transparency, convert to RGBA. If not, convert to RGB + if image.mode in ("RGBA", "LA") or "transparency" in image.info: + image = image.convert("RGBA") + elif image.mode != "RGB": + image = image.convert("RGB") + + # If image is RGBA, combine with white background + if image.mode == "RGBA": + background = Image.new("RGB", image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[3]) # Use alpha channel as mask + image = background -def preprocess_image(image): image = np.array(image) image = image[:, :, ::-1] # RGB->BGR @@ -59,7 +74,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset): img_path = str(self.images[idx]) try: - image = Image.open(img_path).convert("RGB") + image = Image.open(img_path) image = preprocess_image(image) # tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・) except Exception as e: @@ -81,35 +96,64 @@ def collate_fn_remove_corrupted(batch): def main(args): # model location is model_dir + repo_id - # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash - model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir" + # so we split it to "namespace/reponame" and "subdir" + tokens = args.repo_id.split("/") + + if len(tokens) > 2: + repo_id = "/".join(tokens[:2]) + subdir = "/".join(tokens[2:]) + model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir) + onnx_model_name = "model_optimized.onnx" + default_format = False + else: + repo_id = args.repo_id + subdir = None + model_location = os.path.join(args.model_dir, repo_id.replace("/", "_")) + onnx_model_name = "model.onnx" + default_format = True - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(model_location) or args.force_download: os.makedirs(args.model_dir, exist_ok=True) logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - files = FILES - if args.onnx: - files = ["selected_tags.csv"] - files += FILES_ONNX - else: - for file in SUB_DIR_FILES: + + if subdir is None: + # SmilingWolf structure + files = FILES + if args.onnx: + files = ["selected_tags.csv"] + files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + repo_id=args.repo_id, + filename=file, + subfolder=SUB_DIR, + local_dir=os.path.join(model_location, SUB_DIR), + force_download=True, + ) + + for file in files: hf_hub_download( repo_id=args.repo_id, filename=file, - subfolder=SUB_DIR, - local_dir=os.path.join(model_location, SUB_DIR), + local_dir=model_location, + force_download=True, + ) + else: + # another structure + files = [onnx_model_name, "tag_mapping.json"] + + for file in files: + hf_hub_download( + repo_id=repo_id, + filename=file, + subfolder=subdir, + local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified force_download=True, ) - for file in files: - hf_hub_download( - repo_id=args.repo_id, - filename=file, - local_dir=model_location, - force_download=True, - ) else: logger.info("using existing wd14 tagger model") @@ -118,7 +162,7 @@ def main(args): import onnx import onnxruntime as ort - onnx_path = f"{model_location}/model.onnx" + onnx_path = os.path.join(model_location, onnx_model_name) logger.info("Running wd14 tagger with onnx") logger.info(f"loading onnx model: {onnx_path}") @@ -150,39 +194,30 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, providers=(["OpenVINOExecutionProvider"]), - provider_options=[{'device_type' : "GPU", "precision": "FP32"}], + provider_options=[{"device_type": "GPU", "precision": "FP32"}], ) else: - ort_sess = ort.InferenceSession( - onnx_path, - providers=( - ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else - ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else - ["CPUExecutionProvider"] - ), + providers = ( + ["CUDAExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers() + else ( + ["ROCMExecutionProvider"] + if "ROCMExecutionProvider" in ort.get_available_providers() + else ["CPUExecutionProvider"] + ) ) + logger.info(f"Using onnxruntime providers: {providers}") + ort_sess = ort.InferenceSession(onnx_path, providers=providers) else: from tensorflow.keras.models import load_model model = load_model(f"{model_location}") + # We read the CSV file manually to avoid adding dependencies. # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - line = [row for row in reader] - header = line[0] # tag_id,name,category,count - rows = line[1:] - assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - - rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] - general_tags = [row[1] for row in rows[0:] if row[2] == "0"] - character_tags = [row[1] for row in rows[0:] if row[2] == "4"] - - # preprocess tags in advance - if args.character_tag_expand: - for i, tag in enumerate(character_tags): + def expand_character_tags(char_tags): + for i, tag in enumerate(char_tags): if tag.endswith(")"): # chara_name_(series) -> chara_name, series # chara_name_(costume)_(series) -> chara_name_(costume), series @@ -191,30 +226,86 @@ def main(args): if character_tag.endswith("_"): character_tag = character_tag[:-1] series_tag = tags[-1].replace(")", "") - character_tags[i] = character_tag + args.caption_separator + series_tag + char_tags[i] = character_tag + args.caption_separator + series_tag - if args.remove_underscore: - rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags] - general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags] - character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags] + def remove_underscore(tags): + return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags] - if args.tag_replacement is not None: - # escape , and ; in tag_replacement: wd14 tag names may contain , and ; - escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####") + def process_tag_replacement(tags: list[str], tag_replacements_arg: str): + # escape , and ; in tag_replacement: wd14 tag names may contain , and ;, + # so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means + # `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD` + escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####") tag_replacements = escaped_tag_replacements.split(";") - for tag_replacement in tag_replacements: - tags = tag_replacement.split(",") # source, target - assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}" + + for tag_replacements_arg in tag_replacements: + tags = tag_replacements_arg.split(",") # source, target + assert ( + len(tags) == 2 + ), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}" source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags] logger.info(f"replacing tag: {source} -> {target}") - if source in general_tags: - general_tags[general_tags.index(source)] = target - elif source in character_tags: - character_tags[character_tags.index(source)] = target - elif source in rating_tags: - rating_tags[rating_tags.index(source)] = target + if source in tags: + tags[tags.index(source)] = target + + if default_format: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + line = [row for row in reader] + header = line[0] # tag_id,name,category,count + rows = line[1:] + assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" + + rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] + general_tags = [row[1] for row in rows[0:] if row[2] == "0"] + character_tags = [row[1] for row in rows[0:] if row[2] == "4"] + + if args.character_tag_expand: + expand_character_tags(character_tags) + if args.remove_underscore: + rating_tags = remove_underscore(rating_tags) + character_tags = remove_underscore(character_tags) + general_tags = remove_underscore(general_tags) + if args.tag_replacement is not None: + process_tag_replacement(rating_tags, args.tag_replacement) + process_tag_replacement(general_tags, args.tag_replacement) + process_tag_replacement(character_tags, args.tag_replacement) + else: + with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f: + tag_mapping = json.load(f) + + rating_tags = [] + general_tags = [] + character_tags = [] + + tag_id_to_tag_mapping = {} + tag_id_to_category_mapping = {} + for tag_id, tag_info in tag_mapping.items(): + tag = tag_info["tag"] + category = tag_info["category"] + assert category in [ + "Rating", + "General", + "Character", + "Copyright", + "Meta", + "Model", + "Quality", + ], f"unexpected category: {category}" + + if args.remove_underscore: + tag = remove_underscore([tag])[0] + if args.tag_replacement is not None: + tag = process_tag_replacement([tag], args.tag_replacement)[0] + if category == "Character" and args.character_tag_expand: + tag_list = [tag] + expand_character_tags(tag_list) + tag = tag_list[0] + + tag_id_to_tag_mapping[int(tag_id)] = tag + tag_id_to_category_mapping[int(tag_id)] = category # 画像を読み込む train_data_dir_path = Path(args.train_data_dir) @@ -238,6 +329,9 @@ def main(args): if args.onnx: # if len(imgs) < args.batch_size: # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + if not default_format: + imgs = imgs.transpose(0, 3, 1, 2) # to NCHW + imgs = imgs / 127.5 - 1.0 probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -249,42 +343,112 @@ def main(args): rating_tag_text = "" character_tag_text = "" general_tag_text = "" + other_tag_text = "" - # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する - # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold - for i, p in enumerate(prob[4:]): - if i < len(general_tags) and p >= args.general_threshold: - tag_name = general_tags[i] + if default_format: + # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する + # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold + for i, p in enumerate(prob[4:]): + if i < len(general_tags) and p >= args.general_threshold: + tag_name = general_tags[i] - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) - elif i >= len(general_tags) and p >= args.character_threshold: - tag_name = character_tags[i - len(general_tags)] - - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += caption_separator + tag_name - if args.character_tags_first: # insert to the beginning - combined_tags.insert(0, tag_name) - else: + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) + elif i >= len(general_tags) and p >= args.character_threshold: + tag_name = character_tags[i - len(general_tags)] - # 最初の4つはratingなのでargmaxで選ぶ - # First 4 labels are actually ratings: pick one with argmax - if args.use_rating_tags or args.use_rating_tags_as_last_tag: - ratings_probs = prob[:4] - rating_index = ratings_probs.argmax() - found_rating = rating_tags[rating_index] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # insert to the beginning + combined_tags.insert(0, tag_name) + else: + combined_tags.append(tag_name) - if found_rating not in undesired_tags: - tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 - rating_tag_text = found_rating - if args.use_rating_tags: - combined_tags.insert(0, found_rating) # insert to the beginning - else: - combined_tags.append(found_rating) + # 最初の4つはratingなのでargmaxで選ぶ + # First 4 labels are actually ratings: pick one with argmax + if args.use_rating_tags or args.use_rating_tags_as_last_tag: + ratings_probs = prob[:4] + rating_index = ratings_probs.argmax() + found_rating = rating_tags[rating_index] + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + if args.use_rating_tags: + combined_tags.insert(0, found_rating) # insert to the beginning + else: + combined_tags.append(found_rating) + else: + # apply sigmoid to probabilities + prob = 1 / (1 + np.exp(-prob)) + + rating_max_prob = -1 + rating_tag = None + quality_max_prob = -1 + quality_tag = None + character_tags = [] + for i, p in enumerate(prob): + if i in tag_id_to_tag_mapping and p >= args.thresh: + tag_name = tag_id_to_tag_mapping[i] + category = tag_id_to_category_mapping[i] + if tag_name in undesired_tags: + continue + + if category == "Rating": + if p > rating_max_prob: + rating_max_prob = p + rating_tag = tag_name + rating_tag_text = tag_name + continue + elif category == "Quality": + if p > quality_max_prob: + quality_max_prob = p + quality_tag = tag_name + if args.use_quality_tags or args.use_quality_tags_as_last_tag: + other_tag_text += caption_separator + tag_name + continue + + if category == "General" and p >= args.general_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += caption_separator + tag_name + combined_tags.append((tag_name, p)) + elif category == "Character" and p >= args.character_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # we separate character tags + character_tags.append((tag_name, p)) + else: + combined_tags.append((tag_name, p)) + elif ( + (category == "Copyright" and p >= args.copyright_threshold) + or (category == "Meta" and p >= args.meta_threshold) + or (category == "Model" and p >= args.model_threshold) + ): + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + other_tag_text += f"{caption_separator}{tag_name} ({category})" + combined_tags.append((tag_name, p)) + + # sort by probability + combined_tags.sort(key=lambda x: x[1], reverse=True) + if character_tags: + print(character_tags) + character_tags.sort(key=lambda x: x[1], reverse=True) + combined_tags = character_tags + combined_tags + combined_tags = [t[0] for t in combined_tags] # remove probability + + if quality_tag is not None: + if args.use_quality_tags_as_last_tag: + combined_tags.append(quality_tag) + elif args.use_quality_tags: + combined_tags.insert(0, quality_tag) + if rating_tag is not None: + if args.use_rating_tags_as_last_tag: + combined_tags.append(rating_tag) + elif args.use_rating_tags: + combined_tags.insert(0, rating_tag) # 一番最初に置くタグを指定する # Always put some tags at the beginning @@ -299,6 +463,8 @@ def main(args): general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: character_tag_text = character_tag_text[len(caption_separator) :] + if len(other_tag_text) > 0: + other_tag_text = other_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension @@ -328,6 +494,8 @@ def main(args): logger.info(f"\tRating tags: {rating_tag_text}") logger.info(f"\tCharacter tags: {character_tag_text}") logger.info(f"\tGeneral tags: {general_tag_text}") + if other_tag_text: + logger.info(f"\tOther tags: {other_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -353,8 +521,6 @@ def main(args): if image is None: try: image = Image.open(image_path) - if image.mode != "RGB": - image = image.convert("RGB") image = preprocess_image(image) except Exception as e: logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") @@ -381,9 +547,7 @@ def main(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument( - "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ" - ) + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument( "--repo_id", type=str, @@ -401,9 +565,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) - parser.add_argument( - "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ" - ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", type=int, @@ -432,7 +594,29 @@ def setup_parser() -> argparse.ArgumentParser: "--character_threshold", type=float, default=None, - help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags" + " / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる", + ) + parser.add_argument( + "--meta_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags" + " / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる", + ) + parser.add_argument( + "--model_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags" + " / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる", + ) + parser.add_argument( + "--copyright_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags" + " / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる", ) parser.add_argument( "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" @@ -442,9 +626,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", ) - parser.add_argument( - "--debug", action="store_true", help="debug mode" - ) + parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument( "--undesired_tags", type=str, @@ -454,20 +636,34 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する" ) - parser.add_argument( - "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" - ) + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument( "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" ) parser.add_argument( - "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", + "--use_rating_tags", + action="store_true", + help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", ) parser.add_argument( - "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", + "--use_rating_tags_as_last_tag", + action="store_true", + help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", ) parser.add_argument( - "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", + "--use_quality_tags", + action="store_true", + help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する", + ) + parser.add_argument( + "--use_quality_tags_as_last_tag", + action="store_true", + help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する", + ) + parser.add_argument( + "--character_tags_first", + action="store_true", + help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", ) parser.add_argument( "--always_first_tags", @@ -512,5 +708,11 @@ if __name__ == "__main__": args.general_threshold = args.thresh if args.character_threshold is None: args.character_threshold = args.thresh + if args.meta_threshold is None: + args.meta_threshold = args.thresh + if args.model_threshold is None: + args.model_threshold = args.thresh + if args.copyright_threshold is None: + args.copyright_threshold = args.thresh main(args)