From 406511c333d99286f19e9a5bf2de55bccfd5302b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 17:08:58 +0900 Subject: [PATCH] add error message if model.onnx doesn't exist --- finetune/tag_images_by_wd14_tagger.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index ffe94e7d..965edd7e 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,6 +1,5 @@ import argparse import csv -import glob import os from pathlib import Path @@ -19,6 +18,7 @@ IMAGE_SIZE = 448 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] +FILES_ONNX = ["model.onnx"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] @@ -80,9 +80,10 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + files = FILES if args.onnx: - FILES.append("model.onnx") - for file in FILES: + files += FILES_ONNX + for file in files: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: hf_hub_download( @@ -104,18 +105,29 @@ def main(args): onnx_path = f"{args.model_dir}/model.onnx" print("Running wd14 tagger with onnx") print(f"loading onnx model: {onnx_path}") + + if not os.path.exists(onnx_path): + raise Exception( + f"onnx model not found: {onnx_path}, please redownload the model with --force_download" + + " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください" + ) + model = onnx.load(onnx_path) input_name = model.graph.input[0].name try: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param + if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes print( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size + + del model + ort_sess = ort.InferenceSession( onnx_path, providers=["CUDAExecutionProvider"] @@ -154,7 +166,10 @@ def main(args): imgs = np.array([im for _, im in path_imgs]) if args.onnx: + if len(imgs) < args.batch_size: + imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy + probs = probs[: len(path_imgs)] else: probs = model(imgs, training=False) probs = probs.numpy() @@ -333,7 +348,7 @@ def setup_parser() -> argparse.ArgumentParser: help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference") + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") return parser