feat: support another tagger model

This commit is contained in:
Kohya S
2025-08-30 15:47:47 +09:00
parent 1e623e7c31
commit 77ad20bc8f
3 changed files with 324 additions and 114 deletions

View File

@@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github.
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
pip install onnx onnxruntime-gpu
```
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
The model weights will be automatically downloaded from Hugging Face.
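Under the hood the download is just `hf_hub_download` from `huggingface_hub`, which the script calls once per required file. A minimal sketch with placeholder repo id, filename, and directory (the script derives the real ones from `--repo_id` and `--onnx`):
```python
from huggingface_hub import hf_hub_download

# Placeholder values for illustration only
path = hf_hub_download(
    repo_id="namespace/repo_name",                      # whatever is passed via --repo_id
    filename="model.onnx",
    local_dir="wd14_tagger_model/namespace_repo_name",  # script default is model_dir + repo id with "/" replaced
    force_download=False,
)
print(path)  # local path of the downloaded file
```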
# Usage
@@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
# Options
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.

View File

@@ -5,9 +5,11 @@
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
pip install onnx onnxruntime-gpu
```
See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details.
The model weights will be automatically downloaded from Hugging Face.
# Usage
@@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge
# Options
All options can be checked with `python tag_images_by_wd14_tagger.py --help`.
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.

View File

@@ -1,12 +1,13 @@
import argparse
import csv
import json
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, errors
from PIL import Image
from tqdm import tqdm
@@ -29,8 +30,22 @@ SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
TAG_JSON_FILE = "tag_mapping.json"
def preprocess_image(image: Image.Image) -> np.ndarray:
# If image has transparency, convert to RGBA. If not, convert to RGB
if image.mode in ("RGBA", "LA") or "transparency" in image.info:
image = image.convert("RGBA")
elif image.mode != "RGB":
image = image.convert("RGB")
# If image is RGBA, combine with white background
if image.mode == "RGBA":
background = Image.new("RGB", image.size, (255, 255, 255))
background.paste(image, mask=image.split()[3]) # Use alpha channel as mask
image = background
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
@@ -59,7 +74,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
img_path = str(self.images[idx])
try:
image = Image.open(img_path).convert("RGB")
image = Image.open(img_path)
image = preprocess_image(image)
# tensor = torch.tensor(image)  # no need to convert this to a Tensor
except Exception as e:
@@ -81,35 +96,64 @@ def collate_fn_remove_corrupted(batch):
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir"
# so we split it into "namespace/repo_name" and "subdir"
tokens = args.repo_id.split("/")
if len(tokens) > 2:
repo_id = "/".join(tokens[:2])
subdir = "/".join(tokens[2:])
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir)
onnx_model_name = "model_optimized.onnx"
default_format = False
else:
repo_id = args.repo_id
subdir = None
model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"))
onnx_model_name = "model.onnx"
default_format = True
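# Illustration with hypothetical values:
#   --repo_id "user/repo"          -> repo_id="user/repo", subdir=None,       model.onnx,           default_format=True
#   --repo_id "user/repo/some/dir" -> repo_id="user/repo", subdir="some/dir", model_optimized.onnx, default_format=False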
# Using hf_hub_download as-is reportedly causes problems with symlinks, so specify the cache directory and force_filename as a workaround
# A deprecation warning is shown, but we'll deal with that when the option actually goes away
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
if subdir is None:
# SmilingWolf structure
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
local_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
)
for file in files:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
local_dir=os.path.join(model_location, SUB_DIR),
local_dir=model_location,
force_download=True,
)
else:
# another structure
files = [onnx_model_name, "tag_mapping.json"]
for file in files:
hf_hub_download(
repo_id=repo_id,
filename=file,
subfolder=subdir,
local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified
force_download=True,
)
for file in files:
hf_hub_download(
repo_id=args.repo_id,
filename=file,
local_dir=model_location,
force_download=True,
)
else:
logger.info("using existing wd14 tagger model")
@@ -118,7 +162,7 @@ def main(args):
import onnx
import onnxruntime as ort
onnx_path = f"{model_location}/model.onnx"
onnx_path = os.path.join(model_location, onnx_model_name)
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
@@ -150,39 +194,30 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
provider_options=[{"device_type": "GPU", "precision": "FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
providers = (
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
)
logger.info(f"Using onnxruntime providers: {providers}")
ort_sess = ort.InferenceSession(onnx_path, providers=providers)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{model_location}")
# We read the CSV file manually to avoid adding dependencies.
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
def expand_character_tags(char_tags):
for i, tag in enumerate(char_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
@@ -191,30 +226,86 @@ def main(args):
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
char_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
def remove_underscore(tags):
return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
def process_tag_replacement(tags: list[str], tag_replacements_arg: str):
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;,
# so the user must specify them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means
# `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD`
escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####")
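# Worked example using the format shown in the comment above:
#   "aa\,bb,AA\,BB;cc\;dd,CC\;DD" -> after escaping: "aa@@@@bb,AA@@@@BB;cc####dd,CC####DD"
#   split on ";" -> ["aa@@@@bb,AA@@@@BB", "cc####dd,CC####DD"], then each entry splits
#   on "," into (source, target) and the placeholders are turned back into "," and ";"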
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
for tag_replacement in tag_replacements:
pair = tag_replacement.split(",")  # source, target
assert (
len(pair) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in pair]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
if source in tags:
tags[tags.index(source)] = target
return tags
if default_format:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
if args.character_tag_expand:
expand_character_tags(character_tags)
if args.remove_underscore:
rating_tags = remove_underscore(rating_tags)
character_tags = remove_underscore(character_tags)
general_tags = remove_underscore(general_tags)
if args.tag_replacement is not None:
process_tag_replacement(rating_tags, args.tag_replacement)
process_tag_replacement(general_tags, args.tag_replacement)
process_tag_replacement(character_tags, args.tag_replacement)
else:
with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f:
tag_mapping = json.load(f)
rating_tags = []
general_tags = []
character_tags = []
tag_id_to_tag_mapping = {}
tag_id_to_category_mapping = {}
for tag_id, tag_info in tag_mapping.items():
tag = tag_info["tag"]
category = tag_info["category"]
assert category in [
"Rating",
"General",
"Character",
"Copyright",
"Meta",
"Model",
"Quality",
], f"unexpected category: {category}"
if args.remove_underscore:
tag = remove_underscore([tag])[0]
if args.tag_replacement is not None:
tag = process_tag_replacement([tag], args.tag_replacement)[0]
if category == "Character" and args.character_tag_expand:
tag_list = [tag]
expand_character_tags(tag_list)
tag = tag_list[0]
tag_id_to_tag_mapping[int(tag_id)] = tag
tag_id_to_category_mapping[int(tag_id)] = category
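# Assumed shape of tag_mapping.json (hypothetical values), inferred from the keys read above:
#   {"0": {"tag": "some_tag", "category": "General"}, "1": {"tag": "some_character", "category": "Character"}, ...}
# The string keys are tag ids and are also used below as indices into the model output.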
# Load images
train_data_dir_path = Path(args.train_data_dir)
@@ -238,6 +329,9 @@ def main(args):
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
if not default_format:
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
imgs = imgs / 127.5 - 1.0
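# (for the non-default format: inputs go from NHWC to NCHW and pixel values from [0, 255] to [-1, 1])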
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
@@ -249,42 +343,112 @@ def main(args):
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
other_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if default_format:
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
else:
# apply sigmoid to probabilities
prob = 1 / (1 + np.exp(-prob))
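# the non-default model appears to output raw logits, so map them to (0, 1) probabilities here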
rating_max_prob = -1
rating_tag = None
quality_max_prob = -1
quality_tag = None
character_tags = []
for i, p in enumerate(prob):
if i in tag_id_to_tag_mapping and p >= args.thresh:
tag_name = tag_id_to_tag_mapping[i]
category = tag_id_to_category_mapping[i]
if tag_name in undesired_tags:
continue
if category == "Rating":
if p > rating_max_prob:
rating_max_prob = p
rating_tag = tag_name
rating_tag_text = tag_name
continue
elif category == "Quality":
if p > quality_max_prob:
quality_max_prob = p
quality_tag = tag_name
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
other_tag_text += caption_separator + tag_name
continue
if category == "General" and p >= args.general_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append((tag_name, p))
elif category == "Character" and p >= args.character_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # we separate character tags
character_tags.append((tag_name, p))
else:
combined_tags.append((tag_name, p))
elif (
(category == "Copyright" and p >= args.copyright_threshold)
or (category == "Meta" and p >= args.meta_threshold)
or (category == "Model" and p >= args.model_threshold)
):
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
other_tag_text += f"{caption_separator}{tag_name} ({category})"
combined_tags.append((tag_name, p))
# sort by probability
combined_tags.sort(key=lambda x: x[1], reverse=True)
if character_tags:
character_tags.sort(key=lambda x: x[1], reverse=True)
combined_tags = character_tags + combined_tags
combined_tags = [t[0] for t in combined_tags] # remove probability
if quality_tag is not None:
if args.use_quality_tags_as_last_tag:
combined_tags.append(quality_tag)
elif args.use_quality_tags:
combined_tags.insert(0, quality_tag)
if rating_tag is not None:
if args.use_rating_tags_as_last_tag:
combined_tags.append(rating_tag)
elif args.use_rating_tags:
combined_tags.insert(0, rating_tag)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
@@ -299,6 +463,8 @@ def main(args):
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[len(caption_separator) :]
if len(other_tag_text) > 0:
other_tag_text = other_tag_text[len(caption_separator) :]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
@@ -328,6 +494,8 @@ def main(args):
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if other_tag_text:
logger.info(f"\tOther tags: {other_tag_text}")
# Option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
@@ -353,8 +521,6 @@ def main(args):
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
@@ -381,9 +547,7 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--repo_id",
type=str,
@@ -401,9 +565,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -432,7 +594,29 @@ def setup_parser() -> argparse.ArgumentParser:
"--character_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
help="threshold of confidence to add a tag for character category, same as --thres if omitted. set above 1 to disable character tags"
" / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる",
)
parser.add_argument(
"--meta_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. set above 1 to disable meta tags"
" / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる",
)
parser.add_argument(
"--model_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags"
" / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる",
)
parser.add_argument(
"--copyright_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags"
" / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる",
)
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
@@ -442,9 +626,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--undesired_tags",
type=str,
@@ -454,20 +636,34 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--use_quality_tags",
action="store_true",
help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_quality_tags_as_last_tag",
action="store_true",
help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -512,5 +708,11 @@ if __name__ == "__main__":
args.general_threshold = args.thresh
if args.character_threshold is None:
args.character_threshold = args.thresh
if args.meta_threshold is None:
args.meta_threshold = args.thresh
if args.model_threshold is None:
args.model_threshold = args.thresh
if args.copyright_threshold is None:
args.copyright_threshold = args.thresh
main(args)
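For reference, a condensed sketch of the post-processing path this commit adds for the `tag_mapping.json` format. It is not the script itself: the mapping, logits, and thresholds below are made-up placeholders, and only the Rating/General/Character handling is shown (the Copyright/Meta/Model/Quality branches follow the same pattern).
```python
import numpy as np

# Hypothetical tag_mapping.json content; the real file is downloaded with the model.
tag_mapping = {
    "0": {"tag": "general", "category": "Rating"},
    "1": {"tag": "1girl", "category": "General"},
    "2": {"tag": "some_character", "category": "Character"},
}

general_threshold = 0.35
character_threshold = 0.35

logits = np.array([2.0, 1.2, -0.3])  # made-up model output, one value per tag id
probs = 1 / (1 + np.exp(-logits))    # this format outputs logits, so apply sigmoid

rating_tag, rating_max = None, -1.0
scored_tags = []
for tag_id, info in tag_mapping.items():
    p = probs[int(tag_id)]
    if info["category"] == "Rating":
        if p > rating_max:  # ratings: keep only the most confident one
            rating_tag, rating_max = info["tag"], p
    elif info["category"] == "General" and p >= general_threshold:
        scored_tags.append((info["tag"], p))
    elif info["category"] == "Character" and p >= character_threshold:
        scored_tags.append((info["tag"], p))

scored_tags.sort(key=lambda x: x[1], reverse=True)  # sort kept tags by confidence
caption = ", ".join(tag for tag, _ in scored_tags)
if rating_tag is not None:
    caption = rating_tag + ", " + caption  # roughly what --use_rating_tags does
print(caption)
```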