mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add error message if model.onnx doesn't exist
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -19,6 +18,7 @@ IMAGE_SIZE = 448
|
|||||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||||
|
FILES_ONNX = ["model.onnx"]
|
||||||
SUB_DIR = "variables"
|
SUB_DIR = "variables"
|
||||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||||
CSV_FILE = FILES[-1]
|
CSV_FILE = FILES[-1]
|
||||||
@@ -80,9 +80,10 @@ def main(args):
|
|||||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||||
if not os.path.exists(args.model_dir) or args.force_download:
|
if not os.path.exists(args.model_dir) or args.force_download:
|
||||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
|
files = FILES
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
FILES.append("model.onnx")
|
files += FILES_ONNX
|
||||||
for file in FILES:
|
for file in files:
|
||||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||||
for file in SUB_DIR_FILES:
|
for file in SUB_DIR_FILES:
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
@@ -104,18 +105,29 @@ def main(args):
|
|||||||
onnx_path = f"{args.model_dir}/model.onnx"
|
onnx_path = f"{args.model_dir}/model.onnx"
|
||||||
print("Running wd14 tagger with onnx")
|
print("Running wd14 tagger with onnx")
|
||||||
print(f"loading onnx model: {onnx_path}")
|
print(f"loading onnx model: {onnx_path}")
|
||||||
|
|
||||||
|
if not os.path.exists(onnx_path):
|
||||||
|
raise Exception(
|
||||||
|
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
|
||||||
|
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
|
||||||
|
)
|
||||||
|
|
||||||
model = onnx.load(onnx_path)
|
model = onnx.load(onnx_path)
|
||||||
input_name = model.graph.input[0].name
|
input_name = model.graph.input[0].name
|
||||||
try:
|
try:
|
||||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||||
except:
|
except:
|
||||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||||
|
|
||||||
if args.batch_size != batch_size and type(batch_size) != str:
|
if args.batch_size != batch_size and type(batch_size) != str:
|
||||||
# some rebatch model may use 'N' as dynamic axes
|
# some rebatch model may use 'N' as dynamic axes
|
||||||
print(
|
print(
|
||||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||||
)
|
)
|
||||||
args.batch_size = batch_size
|
args.batch_size = batch_size
|
||||||
|
|
||||||
|
del model
|
||||||
|
|
||||||
ort_sess = ort.InferenceSession(
|
ort_sess = ort.InferenceSession(
|
||||||
onnx_path,
|
onnx_path,
|
||||||
providers=["CUDAExecutionProvider"]
|
providers=["CUDAExecutionProvider"]
|
||||||
@@ -154,7 +166,10 @@ def main(args):
|
|||||||
imgs = np.array([im for _, im in path_imgs])
|
imgs = np.array([im for _, im in path_imgs])
|
||||||
|
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
|
if len(imgs) < args.batch_size:
|
||||||
|
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||||
|
probs = probs[: len(path_imgs)]
|
||||||
else:
|
else:
|
||||||
probs = model(imgs, training=False)
|
probs = model(imgs, training=False)
|
||||||
probs = probs.numpy()
|
probs = probs.numpy()
|
||||||
@@ -333,7 +348,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||||
)
|
)
|
||||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
|
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||||
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
Reference in New Issue
Block a user