add onnx to wd14 tagger

This commit is contained in:
Isotr0py
2023-10-08 20:31:10 +08:00
parent 2d87bb648f
commit 70fe7e18be
2 changed files with 47 additions and 12 deletions

View File

@@ -2,16 +2,15 @@ import argparse
import csv import csv
import glob import glob
import os import os
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util import library.train_util as train_util
# from wd14 tagger # from wd14 tagger
@@ -81,6 +80,8 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download: if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
if args.onnx:
FILES.append("model.onnx")
for file in FILES: for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES: for file in SUB_DIR_FILES:
@@ -96,7 +97,35 @@ def main(args):
print("using existing wd14 tagger model") print("using existing wd14 tagger model")
# 画像を読み込む # 画像を読み込む
model = load_model(args.model_dir) if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
ort_sess = ort.InferenceSession(
model.SerializeToString(),
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ # 依存ライブラリを増やしたくないので自力で読むよ
@@ -124,8 +153,11 @@ def main(args):
def run_batch(path_imgs): def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs]) imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False) if args.onnx:
probs = probs.numpy() probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs): for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する # 最初の4つはratingなので無視する
@@ -283,6 +315,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
) )
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
return parser return parser

View File

@@ -19,8 +19,10 @@ huggingface-hub==0.15.1
# requests==2.28.2 # requests==2.28.2
# timm==0.6.12 # timm==0.6.12
# fairscale==0.4.13 # fairscale==0.4.13
# for WD14 captioning # for WD14 captioning (tensorflow or onnx)
# tensorflow==2.10.1 # tensorflow==2.10.1
# onnx==1.14.1
# onnxruntime==1.16.0
# open clip for SDXL # open clip for SDXL
open-clip-torch==2.20.0 open-clip-torch==2.20.0
# for kohya_ss library # for kohya_ss library