mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add --use_rating_tags and --character_tags_first for WD Tagger
This commit is contained in:
@@ -130,10 +130,10 @@ def main(args):
|
|||||||
input_name = model.graph.input[0].name
|
input_name = model.graph.input[0].name
|
||||||
try:
|
try:
|
||||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||||
except:
|
except Exception:
|
||||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||||
|
|
||||||
if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
|
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
|
||||||
# some rebatch model may use 'N' as dynamic axes
|
# some rebatch model may use 'N' as dynamic axes
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||||
@@ -169,9 +169,9 @@ def main(args):
|
|||||||
|
|
||||||
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
l = [row for row in reader]
|
line = [row for row in reader]
|
||||||
header = l[0] # tag_id,name,category,count
|
header = line[0] # tag_id,name,category,count
|
||||||
rows = l[1:]
|
rows = line[1:]
|
||||||
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
|
||||||
|
|
||||||
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
|
||||||
@@ -228,20 +228,24 @@ def main(args):
|
|||||||
if tag_name not in undesired_tags:
|
if tag_name not in undesired_tags:
|
||||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||||
character_tag_text += caption_separator + tag_name
|
character_tag_text += caption_separator + tag_name
|
||||||
combined_tags.insert(0,tag_name) # insert to the beginning
|
if args.character_tags_first: # insert to the beginning
|
||||||
|
combined_tags.insert(0,tag_name)
|
||||||
|
else:
|
||||||
|
combined_tags.append(tag_name)
|
||||||
|
|
||||||
#最初の4つはratingなので無視する
|
#最初の4つはratingなので無視する
|
||||||
# First 4 labels are actually ratings: pick one with argmax
|
# First 4 labels are actually ratings: pick one with argmax
|
||||||
ratings_names = prob[:4]
|
if args.use_rating_tags:
|
||||||
rating_index = ratings_names.argmax()
|
ratings_names = prob[:4]
|
||||||
found_rating = rating_tags[rating_index]
|
rating_index = ratings_names.argmax()
|
||||||
if args.remove_underscore and len(found_rating) > 3:
|
found_rating = rating_tags[rating_index]
|
||||||
found_rating = found_rating.replace("_", " ")
|
if args.remove_underscore and len(found_rating) > 3:
|
||||||
|
found_rating = found_rating.replace("_", " ")
|
||||||
|
|
||||||
if found_rating not in undesired_tags:
|
if found_rating not in undesired_tags:
|
||||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||||
rating_tag_text = found_rating
|
rating_tag_text = found_rating
|
||||||
combined_tags.insert(0,found_rating) # insert to the beginning
|
combined_tags.insert(0,found_rating) # insert to the beginning
|
||||||
|
|
||||||
# 先頭のカンマを取る
|
# 先頭のカンマを取る
|
||||||
if len(general_tag_text) > 0:
|
if len(general_tag_text) > 0:
|
||||||
@@ -332,7 +336,9 @@ def main(args):
|
|||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument(
|
||||||
|
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo_id",
|
"--repo_id",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -350,7 +356,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||||
)
|
)
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_data_loader_n_workers",
|
"--max_data_loader_n_workers",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -389,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||||
)
|
)
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument(
|
||||||
|
"--debug", action="store_true", help="debug mode"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--undesired_tags",
|
"--undesired_tags",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -399,10 +409,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
|
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
|
||||||
)
|
)
|
||||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
parser.add_argument(
|
||||||
|
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--caption_separator",
|
"--caption_separator",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user