Rating support for WD Tagger

This commit is contained in:
Disty0
2024-03-27 21:53:40 +03:00
parent b86af6798d
commit dd9763be31

View File

@@ -174,8 +174,9 @@ def main(args):
rows = l[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# 画像を読み込む
@@ -202,17 +203,13 @@ def main(args):
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
character_tag_text = ""
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
@@ -231,7 +228,20 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
combined_tags.insert(0,tag_name) # insert to the beggining
#最初の4つはratingなので無視する
# First 4 labels are actually ratings: pick one with argmax
ratings_names = prob[:4]
rating_index = ratings_names.argmax()
found_rating = rating_tags[rating_index]
if args.remove_underscore and len(found_rating) > 3:
found_rating = found_rating.replace("_", " ")
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
combined_tags.insert(0,found_rating) # insert to the beggining
# 先頭のカンマを取る
if len(general_tag_text) > 0:
@@ -264,6 +274,7 @@ def main(args):
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")