support new metadata in wd14tagger (WIP), fix typo

This commit is contained in:
Kohya S
2024-11-28 21:05:17 +09:00
parent 665c04e649
commit 2238b94e7b
2 changed files with 237 additions and 64 deletions

View File

@@ -1,7 +1,13 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import csv
import glob
import json
import os
from pathlib import Path
from typing import Any, Optional, Union
import zipfile
import tarfile
import cv2
import numpy as np
@@ -63,13 +69,90 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
size = image.size
image = preprocess_image(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (image, img_path)
return (image, img_path, size)
class ArchiveImageLoader:
    """Iterate over the images stored in .zip / .tar archives in batches.

    Every ``next()`` call returns a list of ``(path, image, size)`` tuples:

    * ``path`` is ``"<archive path>////<member name>"``,
    * ``image`` is whatever ``preprocess_image`` returns for the RGB image,
    * ``size`` is the ``size`` attribute of the decoded RGB image, taken
      before preprocessing.

    A batch may span several archives; only the last batch can be shorter
    than ``batch_size``. Members that cannot be read or decoded are logged
    and skipped, so one broken image does not abort the whole run.
    """

    def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False):
        """
        Args:
            archive_paths: paths of the archives to read (``str`` or
                ``os.PathLike``); each must end with ``.zip`` or ``.tar``.
            batch_size: maximum number of images returned per ``next()`` call.
            debug: when True, log each archive as it is opened.
        """
        self.archive_paths = archive_paths
        self.batch_size = batch_size
        self.debug = debug
        self.current_archive = None  # currently opened ZipFile / TarFile, if any
        self.archive_index = 0  # index into archive_paths of the current archive
        self.image_index = 0  # index into self.files of the next member to load
        self.files = None  # image member names of the current archive
        self.executor = ThreadPoolExecutor()
        self.image_exts = set(train_util.IMAGE_EXTENSIONS)

    def __iter__(self):
        return self

    def __next__(self) -> list[tuple[str, Any, tuple[int, int]]]:
        """Return the next batch of ``(path, image, size)`` tuples.

        Raises:
            StopIteration: when every archive has been consumed.
            ValueError: when an archive is neither ``.zip`` nor ``.tar``.
        """
        images = []
        while len(images) < self.batch_size:
            if self.current_archive is None:
                if self.archive_index >= len(self.archive_paths):
                    if not images:
                        raise StopIteration
                    break  # return the remaining images
                self._open_next_archive()

            # Start from an empty list on every pass: entries of the previous
            # pass have already been moved into ``images``.
            pending = []
            while len(images) + len(pending) < self.batch_size and self.image_index < len(self.files):
                file = self.files[self.image_index]
                self.image_index += 1
                archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
                try:
                    future = self._submit_load(file)
                except Exception as e:
                    logger.error(f"Could not load image path / 画像を読み込めません: {archive_and_image_path}, error: {e}")
                    continue
                pending.append((archive_and_image_path, future))

            # Wait for all pending loads before the archive may be closed.
            for image_path, future in pending:
                try:
                    image, size = future.result()
                except Exception as e:
                    logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                    continue
                images.append((image_path, image, size))

            if self.image_index >= len(self.files):
                self.current_archive.close()
                self.current_archive = None
                self.archive_index += 1

        return images

    def _open_next_archive(self) -> None:
        """Open ``archive_paths[archive_index]`` and list its image members."""
        # str() so that pathlib.Path entries work as well as plain strings.
        archive_path = str(self.archive_paths[self.archive_index])
        if self.debug:
            logger.info(f"loading archive: {archive_path}")

        suffix = os.path.splitext(archive_path)[1].lower()
        if suffix == ".zip":
            archive = zipfile.ZipFile(archive_path)
            names = archive.namelist()
        elif suffix == ".tar":
            archive = tarfile.open(archive_path, "r")
            names = archive.getnames()
        else:
            raise ValueError(f"unsupported archive file: {archive_path}")

        self.current_archive = archive
        self.image_index = 0
        # filter by image extensions
        self.files = [name for name in names if os.path.splitext(name)[1].lower() in self.image_exts]

        if self.debug:
            logger.info(f"found {len(self.files)} images in the archive")

    def _submit_load(self, file: str):
        """Schedule loading of one archive member and return its future."""
        archive = self.current_archive
        if isinstance(archive, tarfile.TarFile):
            import io  # local import: only needed for tar archives

            # TarFile.open() opens an *archive*; members are read through
            # extractfile(). All members share the archive's single file
            # object, so the bytes are read here, in the calling thread, and
            # only the decoding is handed to the pool.
            member = archive.extractfile(file)
            if member is None:
                raise ValueError(f"not a regular file in the archive: {file}")
            with member:
                data = member.read()
            return self.executor.submit(self._decode, io.BytesIO(data))
        return self.executor.submit(self._load_zip_member, archive, file)

    @staticmethod
    def _load_zip_member(archive: zipfile.ZipFile, file: str):
        """Open ``file`` inside the zip ``archive`` and decode it."""
        with archive.open(file) as f:
            return ArchiveImageLoader._decode(f)

    @staticmethod
    def _decode(fileobj):
        """Decode ``fileobj`` as an RGB image; return ``(preprocessed, size)``."""
        image = Image.open(fileobj).convert("RGB")
        size = image.size
        return preprocess_image(image), size
def collate_fn_remove_corrupted(batch):
@@ -149,15 +232,19 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
provider_options=[{"device_type": "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
),
)
else:
@@ -203,7 +290,9 @@ def main(args):
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
assert (
len(tags) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
@@ -216,9 +305,15 @@ def main(args):
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
tag_freq = {}
@@ -231,19 +326,23 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
def run_batch(
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
tags_index: Optional[int] = None,
):
imgs = np.array([im for _, im, _ in list_of_path_img_size])
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
probs = probs[: len(list_of_path_img_size)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
@@ -265,7 +364,7 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
@@ -281,7 +380,7 @@ def main(args):
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
@@ -304,12 +403,24 @@ def main(args):
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
existing_content = None
if images_metadata is None:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
else:
image_md = images_metadata.get(image_path, None)
if image_md is not None:
tags = image_md.get("tags", None)
if tags is not None:
if tags_index is None and len(tags) > 0:
existing_content = tags[-1]
elif tags_index is not None and tags_index < len(tags):
existing_content = tags[tags_index]
if existing_content is not None:
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -319,19 +430,62 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": [image_size.width, image_size.height]}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
if tags_index is None:
image_md["tags"].append(tag_text)
else:
while len(image_md["tags"]) <= tags_index:
image_md["tags"].append("")
image_md["tags"][tags_index] = tag_text
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# load metadata if needed
metadata = None
if args.metadata is not None:
if os.path.exists(args.metadata):
logger.info(f"loading metadata file: {args.metadata}")
with open(args.metadata, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
if major > 1 or (major == 1 and minor > 0):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
logger.info(f"metadata file not found: {args.metadata}, creating new metadata")
metadata = {"format_version": "1.0.0", "images": {}}
images_metadata = metadata["images"]
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = ArchiveImageLoader(image_paths, args.batch_size)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(
loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
@@ -339,35 +493,37 @@ def main(args):
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
use_loader = True
else:
data = [[(None, ip)] for ip in image_paths]
# make batch of image paths
loader = []
for i in range(0, len(image_paths), args.batch_size):
loader.append(image_paths[i : i + args.batch_size])
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
for data_entry in tqdm(loader, smoothing=0.0):
if use_loader:
b_imgs = data_entry
else:
b_imgs = []
for image_path in data_entry:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
size = image.size
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
b_imgs.append((image_path, image, size))
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.tags_index)
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -380,9 +536,7 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--repo_id",
type=str,
@@ -400,9 +554,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -441,9 +593,7 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--undesired_tags",
type=str,
@@ -453,20 +603,24 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -495,6 +649,25 @@ def setup_parser() -> argparse.ArgumentParser:
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
)
parser.add_argument(
"--tags_index",
type=int,
default=None,
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
)
parser.add_argument(
"--load_archive",
action="store_true",
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
)
return parser

View File

@@ -84,7 +84,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_tolen_length,
args.max_token_length,
is_weighted=args.weighted_captions,
)
else: