mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add onnx to wd14 tagger
This commit is contained in:
@@ -2,16 +2,15 @@ import argparse
|
|||||||
import csv
|
import csv
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
# from wd14 tagger
|
# from wd14 tagger
|
||||||
@@ -81,6 +80,8 @@ def main(args):
|
|||||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||||
if not os.path.exists(args.model_dir) or args.force_download:
|
if not os.path.exists(args.model_dir) or args.force_download:
|
||||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
|
if args.onnx:
|
||||||
|
FILES.append("model.onnx")
|
||||||
for file in FILES:
|
for file in FILES:
|
||||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||||
for file in SUB_DIR_FILES:
|
for file in SUB_DIR_FILES:
|
||||||
@@ -96,7 +97,35 @@ def main(args):
|
|||||||
print("using existing wd14 tagger model")
|
print("using existing wd14 tagger model")
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
model = load_model(args.model_dir)
|
if args.onnx:
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
onnx_path = f"{args.model_dir}/model.onnx"
|
||||||
|
print("Running wd14 tagger with onnx")
|
||||||
|
print(f"loading onnx model: {onnx_path}")
|
||||||
|
model = onnx.load(onnx_path)
|
||||||
|
input_name = model.graph.input[0].name
|
||||||
|
try:
|
||||||
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||||
|
except:
|
||||||
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||||
|
if args.batch_size != batch_size and type(batch_size) != str:
|
||||||
|
# some rebatch model may use 'N' as dynamic axes
|
||||||
|
print(
|
||||||
|
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||||
|
)
|
||||||
|
args.batch_size = batch_size
|
||||||
|
ort_sess = ort.InferenceSession(
|
||||||
|
model.SerializeToString(),
|
||||||
|
providers=["CUDAExecutionProvider"]
|
||||||
|
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||||
|
else ["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
|
||||||
|
model = load_model(f"{args.model_dir}")
|
||||||
|
|
||||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||||
@@ -124,6 +153,9 @@ def main(args):
|
|||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
imgs = np.array([im for _, im in path_imgs])
|
imgs = np.array([im for _, im in path_imgs])
|
||||||
|
|
||||||
|
if args.onnx:
|
||||||
|
probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy
|
||||||
|
else:
|
||||||
probs = model(imgs, training=False)
|
probs = model(imgs, training=False)
|
||||||
probs = probs.numpy()
|
probs = probs.numpy()
|
||||||
|
|
||||||
@@ -283,6 +315,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||||
)
|
)
|
||||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||||
|
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ huggingface-hub==0.15.1
|
|||||||
# requests==2.28.2
|
# requests==2.28.2
|
||||||
# timm==0.6.12
|
# timm==0.6.12
|
||||||
# fairscale==0.4.13
|
# fairscale==0.4.13
|
||||||
# for WD14 captioning
|
# for WD14 captioning (tensorflow or onnx)
|
||||||
# tensorflow==2.10.1
|
# tensorflow==2.10.1
|
||||||
|
# onnx==1.14.1
|
||||||
|
# onnxruntime==1.16.0
|
||||||
# open clip for SDXL
|
# open clip for SDXL
|
||||||
open-clip-torch==2.20.0
|
open-clip-torch==2.20.0
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
|
|||||||
Reference in New Issue
Block a user