Update tag_images_by_wd14_tagger.py

add WDV3
This commit is contained in:
青龍聖者@bdsqlsz
2024-03-18 22:29:05 +08:00
parent f9317052ed
commit a7dff592d3

View File

@@ -86,23 +86,26 @@ def main(args):
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
else:
logger.info("using existing wd14 tagger model")
# 画像を読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort