From bc586ce190e1e85adcd7d9734636fac068bc929e Mon Sep 17 00:00:00 2001 From: Disty0 Date: Fri, 29 Mar 2024 13:56:42 +0300 Subject: [PATCH] Add --use_rating_tags and --character_tags_first for WD Tagger --- finetune/tag_images_by_wd14_tagger.py | 56 ++++++++++++++++++--------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 4003210e..16a26179 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -130,10 +130,10 @@ def main(args): input_name = model.graph.input[0].name try: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value - except: + except Exception: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: + if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0: # some rebatch model may use 'N' as dynamic axes logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" @@ -169,9 +169,9 @@ def main(args): with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] + line = [row for row in reader] + header = line[0] # tag_id,name,category,count + rows = line[1:] assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] @@ -228,20 +228,24 @@ def main(args): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 character_tag_text += caption_separator + tag_name - combined_tags.insert(0,tag_name) # insert to the beginning + if args.character_tags_first: # insert to the beginning + combined_tags.insert(0,tag_name) + else: + combined_tags.append(tag_name) #最初の4つはratingなので無視する # First 4 labels are actually ratings: pick one with argmax - ratings_names = prob[:4] - rating_index = ratings_names.argmax() - found_rating = rating_tags[rating_index] - if args.remove_underscore and len(found_rating) > 3: - found_rating = found_rating.replace("_", " ") + if args.use_rating_tags: + ratings_names = prob[:4] + rating_index = ratings_names.argmax() + found_rating = rating_tags[rating_index] + if args.remove_underscore and len(found_rating) > 3: + found_rating = found_rating.replace("_", " ") - if found_rating not in undesired_tags: - tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 - rating_tag_text = found_rating - combined_tags.insert(0,found_rating) # insert to the beginning + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + combined_tags.insert(0,found_rating) # insert to the beginning # 先頭のカンマを取る if len(general_tag_text) > 0: @@ -332,7 +336,9 @@ def main(args): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ" + ) parser.add_argument( "--repo_id", type=str, @@ -350,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument( + "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ" + ) parser.add_argument( "--max_data_loader_n_workers", type=int, @@ -389,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", ) - parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument( + "--debug", action="store_true", help="debug mode" + ) parser.add_argument( "--undesired_tags", type=str, @@ -399,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" ) - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") + parser.add_argument( + "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" + ) parser.add_argument( "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" ) + parser.add_argument( + "--use_rating_tags", action="store_true", help="Adds rating tags as the first tag", + ) + parser.add_argument( + "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags", + ) parser.add_argument( "--caption_separator", type=str,