mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add error message if model.onnx doesn't exist
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -19,6 +18,7 @@ IMAGE_SIZE = 448
|
||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||
FILES_ONNX = ["model.onnx"]
|
||||
SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
@@ -80,9 +80,10 @@ def main(args):
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
FILES.append("model.onnx")
|
||||
for file in FILES:
|
||||
files += FILES_ONNX
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
@@ -104,18 +105,29 @@ def main(args):
|
||||
onnx_path = f"{args.model_dir}/model.onnx"
|
||||
print("Running wd14 tagger with onnx")
|
||||
print(f"loading onnx model: {onnx_path}")
|
||||
|
||||
if not os.path.exists(onnx_path):
|
||||
raise Exception(
|
||||
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
|
||||
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
|
||||
)
|
||||
|
||||
model = onnx.load(onnx_path)
|
||||
input_name = model.graph.input[0].name
|
||||
try:
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||
except:
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||
|
||||
if args.batch_size != batch_size and type(batch_size) != str:
|
||||
# some rebatch model may use 'N' as dynamic axes
|
||||
print(
|
||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||
)
|
||||
args.batch_size = batch_size
|
||||
|
||||
del model
|
||||
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=["CUDAExecutionProvider"]
|
||||
@@ -154,7 +166,10 @@ def main(args):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
if args.onnx:
|
||||
if len(imgs) < args.batch_size:
|
||||
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
@@ -333,7 +348,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||
)
|
||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user