diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index c19ae974..1d49afc7 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -174,8 +174,9 @@ def main(args): rows = l[1:] assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - general_tags = [row[1] for row in rows[1:] if row[2] == "0"] - character_tags = [row[1] for row in rows[1:] if row[2] == "4"] + rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] + general_tags = [row[1] for row in rows[0:] if row[2] == "0"] + character_tags = [row[1] for row in rows[0:] if row[2] == "4"] # 画像を読み込む @@ -202,17 +203,13 @@ def main(args): probs = probs.numpy() for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + combined_tags = [] + rating_tag_text = "" + character_tag_text = "" + general_tag_text = "" # それ以降はタグなのでconfidenceがthresholdより高いものを追加する # Everything else is tags: pick any where prediction confidence > threshold - combined_tags = [] - general_tag_text = "" - character_tag_text = "" for i, p in enumerate(prob[4:]): if i < len(general_tags) and p >= args.general_threshold: tag_name = general_tags[i] @@ -231,7 +228,20 @@ def main(args): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) + combined_tags.insert(0,tag_name) # insert at the beginning + + #最初の4つはratingなので無視する + # First 4 labels are actually ratings: pick one with argmax + ratings_names = prob[:4] + rating_index = ratings_names.argmax() + found_rating = rating_tags[rating_index] + if args.remove_underscore and len(found_rating) > 3: 
+ found_rating = found_rating.replace("_", " ") + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + combined_tags.insert(0,found_rating) # insert at the beginning # 先頭のカンマを取る if len(general_tag_text) > 0: @@ -264,6 +274,7 @@ def main(args): if args.debug: logger.info("") logger.info(f"{image_path}:") + logger.info(f"\tRating tags: {rating_tag_text}") logger.info(f"\tCharacter tags: {character_tag_text}") logger.info(f"\tGeneral tags: {general_tag_text}")