mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update tag_images_by_wd14_tagger.py
add WDV3
This commit is contained in:
@@ -86,23 +86,26 @@ def main(args):
|
|||||||
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
files = FILES
|
files = FILES
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
|
files = ["selected_tags.csv"]
|
||||||
files += FILES_ONNX
|
files += FILES_ONNX
|
||||||
|
else:
|
||||||
|
for file in SUB_DIR_FILES:
|
||||||
|
hf_hub_download(
|
||||||
|
args.repo_id,
|
||||||
|
file,
|
||||||
|
subfolder=SUB_DIR,
|
||||||
|
cache_dir=os.path.join(args.model_dir, SUB_DIR),
|
||||||
|
force_download=True,
|
||||||
|
force_filename=file,
|
||||||
|
)
|
||||||
for file in files:
|
for file in files:
|
||||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||||
for file in SUB_DIR_FILES:
|
|
||||||
hf_hub_download(
|
|
||||||
args.repo_id,
|
|
||||||
file,
|
|
||||||
subfolder=SUB_DIR,
|
|
||||||
cache_dir=os.path.join(args.model_dir, SUB_DIR),
|
|
||||||
force_download=True,
|
|
||||||
force_filename=file,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("using existing wd14 tagger model")
|
logger.info("using existing wd14 tagger model")
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
|
import torch
|
||||||
import onnx
|
import onnx
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user