diff --git a/README.md b/README.md index f0cad611..d0320403 100644 --- a/README.md +++ b/README.md @@ -260,7 +260,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end. - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order. - The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例). - +- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds! + - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`. +- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`. - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 @@ -269,6 +271,11 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。 - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。 - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。 +- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。 + - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。 +- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。 + + #### Example of dataset settings / データセット設定の記述例: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index e63ec3eb..401c6d1e 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -12,8 +12,10 @@ from tqdm import tqdm import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # from wd14 tagger @@ -79,10 +81,15 @@ def collate_fn_remove_corrupted(batch): def main(args): + # model location is model_dir + repo_id + # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash + model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: + if not os.path.exists(model_location) or args.force_download: + os.makedirs(args.model_dir, exist_ok=True) logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") files = FILES if args.onnx: @@ -94,12 +101,12 @@ def main(args): args.repo_id, file, subfolder=SUB_DIR, - cache_dir=os.path.join(args.model_dir, SUB_DIR), + cache_dir=os.path.join(model_location, SUB_DIR), force_download=True, force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) else: logger.info("using existing wd14 tagger model") @@ -109,7 +116,7 @@ def main(args): import onnx import onnxruntime as ort - onnx_path = f"{args.model_dir}/model.onnx" + onnx_path = f"{model_location}/model.onnx" logger.info("Running wd14 tagger with onnx") logger.info(f"loading onnx model: {onnx_path}") @@ -126,7 +133,7 @@ def main(args): except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param - if args.batch_size != batch_size and type(batch_size) != str: + if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0: # some rebatch model may use 'N' as dynamic axes logger.warning( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" @@ -137,19 +144,19 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, - providers=["CUDAExecutionProvider"] - if "CUDAExecutionProvider" in ort.get_available_providers() - else ["CPUExecutionProvider"], + providers=( + ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"] + ), ) else: from tensorflow.keras.models import load_model - model = load_model(f"{args.model_dir}") + model = load_model(f"{model_location}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f) l = [row for row in reader] header = l[0] # tag_id,name,category,count @@ -175,8 +182,8 @@ def main(args): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - if len(imgs) < args.batch_size: - imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + # if len(imgs) < args.batch_size: + # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -317,7 +324,9 @@ def setup_parser() -> argparse.ArgumentParser: help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ", ) parser.add_argument( - "--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします" + "--force_download", + action="store_true", + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( @@ -332,8 +341,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", ) - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") + parser.add_argument( + "--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子" + ) + parser.add_argument( + "--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値" + ) parser.add_argument( "--general_threshold", type=float, @@ -346,7 +359,9 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", ) - parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") + parser.add_argument( + "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" + ) parser.add_argument( "--remove_underscore", action="store_true", @@ -359,9 +374,13 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument( + "--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する" + ) parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") - parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + ) parser.add_argument( "--caption_separator", type=str,