diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 965edd7e..fbf328e8 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -160,7 +160,9 @@ def main(args): tag_freq = {} - undesired_tags = set(args.undesired_tags.split(",")) + caption_separator = args.caption_separator + stripped_caption_separator = caption_separator.strip() + undesired_tags = set(args.undesired_tags.split(stripped_caption_separator)) def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) @@ -194,7 +196,7 @@ def main(args): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += ", " + tag_name + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) elif i >= len(general_tags) and p >= args.character_threshold: tag_name = character_tags[i - len(general_tags)] @@ -203,18 +205,18 @@ def main(args): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += ", " + tag_name + character_tag_text += caption_separator + tag_name combined_tags.append(tag_name) # 先頭のカンマを取る if len(general_tag_text) > 0: - general_tag_text = general_tag_text[2:] + general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: - character_tag_text = character_tag_text[2:] + character_tag_text = character_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension - tag_text = ", ".join(combined_tags) + tag_text = caption_separator.join(combined_tags) if args.append_tags: # Check if file exists @@ -224,13 +226,13 @@ def main(args): existing_content = f.read().strip("\n") # Remove newlines # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()] # Check and remove repeating tags in tag_text new_tags = [tag for tag in combined_tags if tag not in existing_tags] # Create new tag_text - tag_text = ", ".join(existing_tags + new_tags) + tag_text = caption_separator.join(existing_tags + new_tags) with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") @@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--caption_separator", + type=str, + default=", ", + help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください", + ) return parser