mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Rating support for WD Tagger
This commit is contained in:
@@ -174,8 +174,9 @@ def main(args):
|
|||||||
rows = l[1:]
|
rows = l[1:]
|
||||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||||
|
|
||||||
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
|
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||||
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
|
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
|
||||||
|
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
|
|
||||||
@@ -202,17 +203,13 @@ def main(args):
|
|||||||
probs = probs.numpy()
|
probs = probs.numpy()
|
||||||
|
|
||||||
for (image_path, _), prob in zip(path_imgs, probs):
|
for (image_path, _), prob in zip(path_imgs, probs):
|
||||||
# 最初の4つはratingなので無視する
|
combined_tags = []
|
||||||
# # First 4 labels are actually ratings: pick one with argmax
|
rating_tag_text = ""
|
||||||
# ratings_names = label_names[:4]
|
character_tag_text = ""
|
||||||
# rating_index = ratings_names["probs"].argmax()
|
general_tag_text = ""
|
||||||
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
|
||||||
|
|
||||||
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||||
# Everything else is tags: pick any where prediction confidence > threshold
|
# Everything else is tags: pick any where prediction confidence > threshold
|
||||||
combined_tags = []
|
|
||||||
general_tag_text = ""
|
|
||||||
character_tag_text = ""
|
|
||||||
for i, p in enumerate(prob[4:]):
|
for i, p in enumerate(prob[4:]):
|
||||||
if i < len(general_tags) and p >= args.general_threshold:
|
if i < len(general_tags) and p >= args.general_threshold:
|
||||||
tag_name = general_tags[i]
|
tag_name = general_tags[i]
|
||||||
@@ -231,7 +228,20 @@ def main(args):
|
|||||||
if tag_name not in undesired_tags:
|
if tag_name not in undesired_tags:
|
||||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||||
character_tag_text += caption_separator + tag_name
|
character_tag_text += caption_separator + tag_name
|
||||||
combined_tags.append(tag_name)
|
combined_tags.insert(0,tag_name) # insert to the beggining
|
||||||
|
|
||||||
|
#最初の4つはratingなので無視する
|
||||||
|
# First 4 labels are actually ratings: pick one with argmax
|
||||||
|
ratings_names = prob[:4]
|
||||||
|
rating_index = ratings_names.argmax()
|
||||||
|
found_rating = rating_tags[rating_index]
|
||||||
|
if args.remove_underscore and len(found_rating) > 3:
|
||||||
|
found_rating = found_rating.replace("_", " ")
|
||||||
|
|
||||||
|
if found_rating not in undesired_tags:
|
||||||
|
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||||
|
rating_tag_text = found_rating
|
||||||
|
combined_tags.insert(0,found_rating) # insert to the beggining
|
||||||
|
|
||||||
# 先頭のカンマを取る
|
# 先頭のカンマを取る
|
||||||
if len(general_tag_text) > 0:
|
if len(general_tag_text) > 0:
|
||||||
@@ -264,6 +274,7 @@ def main(args):
|
|||||||
if args.debug:
|
if args.debug:
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info(f"{image_path}:")
|
logger.info(f"{image_path}:")
|
||||||
|
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user