Add --use_rating_tags and --character_tags_first for WD Tagger

This commit is contained in:
Disty0
2024-03-29 13:56:42 +03:00
parent 4012fd24f6
commit bc586ce190

View File

@@ -130,10 +130,10 @@ def main(args):
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
@@ -169,9 +169,9 @@ def main(args):
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
@@ -228,20 +228,24 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
combined_tags.insert(0,tag_name) # insert to the beginning
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0,tag_name)
else:
combined_tags.append(tag_name)
#最初の4つはratingなので無視する
# First 4 labels are actually ratings: pick one with argmax
ratings_names = prob[:4]
rating_index = ratings_names.argmax()
found_rating = rating_tags[rating_index]
if args.remove_underscore and len(found_rating) > 3:
found_rating = found_rating.replace("_", " ")
if args.use_rating_tags:
ratings_names = prob[:4]
rating_index = ratings_names.argmax()
found_rating = rating_tags[rating_index]
if args.remove_underscore and len(found_rating) > 3:
found_rating = found_rating.replace("_", " ")
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
combined_tags.insert(0,found_rating) # insert to the beginning
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
combined_tags.insert(0,found_rating) # insert to the beginning
# 先頭のカンマを取る
if len(general_tag_text) > 0:
@@ -332,7 +336,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -350,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -389,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
@@ -399,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags",
)
parser.add_argument(
"--caption_separator",
type=str,