mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-13 16:07:18 +00:00
support new metadata in wd14tagger (WIP), fix typo
This commit is contained in:
@@ -1,7 +1,13 @@
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import csv
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
import zipfile
|
||||
import tarfile
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -63,13 +69,90 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
size = image.size
|
||||
image = preprocess_image(image)
|
||||
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (image, img_path)
|
||||
return (image, img_path, size)
|
||||
|
||||
|
||||
class ArchiveImageLoader:
    """Iterate over the images stored inside ``.zip`` / ``.tar`` archives in batches.

    The loader is a single-pass iterator. Every ``next()`` call returns a list of
    at most ``batch_size`` tuples ``(archive_and_image_path, image, size)``:

    * ``archive_and_image_path`` is ``"<archive path>////<member name>"``,
    * ``image`` is whatever ``preprocess_image`` returns for the decoded RGB image,
    * ``size`` is the PIL ``Image.size`` of the image before preprocessing.

    Archives are processed in the given order and a batch may span two archives.
    Members whose extension is not in ``train_util.IMAGE_EXTENSIONS`` are ignored.
    Members that cannot be read or decoded are logged and skipped, mirroring the
    behavior of the file-based loaders in this script.

    Args:
        archive_paths: paths of the archives to read (``str`` or ``os.PathLike``).
        batch_size: maximum number of images returned per ``next()`` call.
        debug: when true, log which archive is opened and how many images it holds.

    Raises:
        ValueError: from ``next()`` when an archive is neither ``.zip`` nor ``.tar``.
    """

    def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False):
        # The caller may hand over pathlib.Path objects; normalize to str so that
        # the suffix checks and the "////" path formatting work for both.
        self.archive_paths = [os.fspath(path) for path in archive_paths]
        self.batch_size = batch_size
        self.debug = debug
        self.current_archive = None  # currently opened ZipFile / TarFile, if any
        self.archive_index = 0  # index of the archive being read
        self.image_index = 0  # index of the next entry of self.files to load
        self.files = None  # image member names of the current archive
        self.executor = ThreadPoolExecutor()
        self.image_exts = set(train_util.IMAGE_EXTENSIONS)

    def __iter__(self):
        return self

    def __next__(self):
        images = []
        while len(images) < self.batch_size:
            if self.current_archive is None:
                if self.archive_index >= len(self.archive_paths):
                    if not images:
                        # Nothing left at all: release the worker threads.
                        self.executor.shutdown(wait=False)
                        raise StopIteration
                    break  # return the remaining images
                self._open_current_archive()

            archive_path = os.fspath(self.archive_paths[self.archive_index])

            pending = []
            while len(images) + len(pending) < self.batch_size and self.image_index < len(self.files):
                file = self.files[self.image_index]
                self.image_index += 1
                archive_and_image_path = f"{archive_path}////{file}"

                # The raw bytes are read here, on the calling thread, because the
                # archive's underlying file handle is shared; only the decoding and
                # preprocessing are handed to the thread pool.
                try:
                    data = self._read_member(file)
                except Exception as e:
                    logger.error(f"Could not load image path / 画像を読み込めません: {archive_and_image_path}, error: {e}")
                    continue
                pending.append((archive_and_image_path, self.executor.submit(self._decode_image, data)))

            # Wait for every submitted image of this round before moving on.
            for image_path, future in pending:
                try:
                    image, size = future.result()
                except Exception as e:
                    logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                    continue
                images.append((image_path, image, size))

            if self.image_index >= len(self.files):
                # Current archive is exhausted; the next round opens the following one.
                self.current_archive.close()
                self.current_archive = None
                self.archive_index += 1

        return images

    def _open_current_archive(self) -> None:
        """Open the archive at ``archive_index`` and collect its image member names."""
        archive_path = os.fspath(self.archive_paths[self.archive_index])
        if self.debug:
            logger.info(f"loading archive: {archive_path}")

        lowered = archive_path.lower()
        if lowered.endswith(".zip"):
            archive = zipfile.ZipFile(archive_path)
            names = archive.namelist()
        elif lowered.endswith(".tar"):
            archive = tarfile.open(archive_path, "r")
            names = archive.getnames()
        else:
            raise ValueError(f"unsupported archive file: {archive_path}")

        self.current_archive = archive
        self.image_index = 0

        # filter by image extensions
        self.files = [name for name in names if os.path.splitext(name)[1].lower() in self.image_exts]

        if self.debug:
            logger.info(f"found {len(self.files)} images in the archive")

    def _read_member(self, name: str) -> bytes:
        """Return the raw bytes of member ``name`` of the currently opened archive."""
        archive = self.current_archive
        if isinstance(archive, tarfile.TarFile):
            # TarFile.open() opens a *new archive*; members are read with extractfile(),
            # which returns None for entries that are not regular files.
            member = archive.extractfile(name)
            if member is None:
                raise ValueError(f"not a regular file in the archive: {name}")
            with member:
                return member.read()
        with archive.open(name) as f:
            return f.read()

    @staticmethod
    def _decode_image(data: bytes):
        """Decode image bytes to RGB and return ``(preprocessed image, original size)``."""
        import io  # local import: only needed here, keeps the file-level imports untouched

        with io.BytesIO(data) as buffer:
            image = Image.open(buffer).convert("RGB")
        size = image.size
        image = preprocess_image(image)
        return image, size
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
@@ -149,15 +232,19 @@ def main(args):
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(["OpenVINOExecutionProvider"]),
|
||||
provider_options=[{'device_type' : "GPU_FP32"}],
|
||||
provider_options=[{"device_type": "GPU_FP32"}],
|
||||
)
|
||||
else:
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=(
|
||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||
["CPUExecutionProvider"]
|
||||
["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else (
|
||||
["ROCMExecutionProvider"]
|
||||
if "ROCMExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -203,7 +290,9 @@ def main(args):
|
||||
tag_replacements = escaped_tag_replacements.split(";")
|
||||
for tag_replacement in tag_replacements:
|
||||
tags = tag_replacement.split(",") # source, target
|
||||
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
assert (
|
||||
len(tags) == 2
|
||||
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
|
||||
|
||||
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
|
||||
logger.info(f"replacing tag: {source} -> {target}")
|
||||
@@ -216,9 +305,15 @@ def main(args):
|
||||
rating_tags[rating_tags.index(source)] = target
|
||||
|
||||
# 画像を読み込む
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
if not args.load_archive:
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
else:
|
||||
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
|
||||
os.path.join(args.train_data_dir, "*.tar")
|
||||
)
|
||||
image_paths = [Path(archive_file) for archive_file in archive_files]
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -231,19 +326,23 @@ def main(args):
|
||||
if args.always_first_tags is not None:
|
||||
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
def run_batch(
|
||||
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
|
||||
images_metadata: Optional[dict[str, Any]],
|
||||
tags_index: Optional[int] = None,
|
||||
):
|
||||
imgs = np.array([im for _, im, _ in list_of_path_img_size])
|
||||
|
||||
if args.onnx:
|
||||
# if len(imgs) < args.batch_size:
|
||||
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
probs = probs[: len(list_of_path_img_size)]
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
|
||||
combined_tags = []
|
||||
rating_tag_text = ""
|
||||
character_tag_text = ""
|
||||
@@ -265,7 +364,7 @@ def main(args):
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += caption_separator + tag_name
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
if args.character_tags_first: # insert to the beginning
|
||||
combined_tags.insert(0, tag_name)
|
||||
else:
|
||||
combined_tags.append(tag_name)
|
||||
@@ -281,7 +380,7 @@ def main(args):
|
||||
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
|
||||
rating_tag_text = found_rating
|
||||
if args.use_rating_tags:
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
combined_tags.insert(0, found_rating) # insert to the beginning
|
||||
else:
|
||||
combined_tags.append(found_rating)
|
||||
|
||||
@@ -304,12 +403,24 @@ def main(args):
|
||||
tag_text = caption_separator.join(combined_tags)
|
||||
|
||||
if args.append_tags:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
existing_content = None
|
||||
if images_metadata is None:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is not None:
|
||||
tags = image_md.get("tags", None)
|
||||
if tags is not None:
|
||||
if tags_index is None and len(tags) > 0:
|
||||
existing_content = tags[-1]
|
||||
elif tags_index is not None and tags_index < len(tags):
|
||||
existing_content = tags[tags_index]
|
||||
|
||||
if existing_content is not None:
|
||||
# Split the content into tags and store them in a list
|
||||
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
|
||||
|
||||
@@ -319,19 +430,62 @@ def main(args):
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
if images_metadata is None:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is None:
|
||||
image_md = {"image_size": [image_size.width, image_size.height]}
|
||||
images_metadata[image_path] = image_md
|
||||
if "tags" not in image_md:
|
||||
image_md["tags"] = []
|
||||
if tags_index is None:
|
||||
image_md["tags"].append(tag_text)
|
||||
else:
|
||||
while len(image_md["tags"]) <= tags_index:
|
||||
image_md["tags"].append("")
|
||||
image_md["tags"][tags_index] = tag_text
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tRating tags: {rating_tag_text}")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
|
||||
# load metadata if needed
|
||||
metadata = None
|
||||
if args.metadata is not None:
|
||||
if os.path.exists(args.metadata):
|
||||
logger.info(f"loading metadata file: {args.metadata}")
|
||||
with open(args.metadata, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# version check
|
||||
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
logger.warning(
|
||||
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
|
||||
)
|
||||
|
||||
if "images" not in metadata:
|
||||
metadata["images"] = {}
|
||||
else:
|
||||
logger.info(f"metadata file not found: {args.metadata}, creating new metadata")
|
||||
metadata = {"format_version": "1.0.0", "images": {}}
|
||||
|
||||
images_metadata = metadata["images"]
|
||||
|
||||
# prepare DataLoader or something similar :)
|
||||
use_loader = False
|
||||
if args.load_archive:
|
||||
loader = ArchiveImageLoader(image_paths, args.batch_size)
|
||||
use_loader = True
|
||||
elif args.max_data_loader_n_workers is not None:
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
dataset = ImageLoadingPrepDataset(image_paths)
|
||||
data = torch.utils.data.DataLoader(
|
||||
loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
@@ -339,35 +493,37 @@ def main(args):
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
use_loader = True
|
||||
else:
|
||||
data = [[(None, ip)] for ip in image_paths]
|
||||
# make batch of image paths
|
||||
loader = []
|
||||
for i in range(0, len(image_paths), args.batch_size):
|
||||
loader.append(image_paths[i : i + args.batch_size])
|
||||
|
||||
b_imgs = []
|
||||
for data_entry in tqdm(data, smoothing=0.0):
|
||||
for data in data_entry:
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
image, image_path = data
|
||||
if image is None:
|
||||
for data_entry in tqdm(loader, smoothing=0.0):
|
||||
if use_loader:
|
||||
b_imgs = data_entry
|
||||
else:
|
||||
b_imgs = []
|
||||
for image_path in data_entry:
|
||||
try:
|
||||
image = Image.open(image_path)
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
size = image.size
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
b_imgs.append((image_path, image, size))
|
||||
|
||||
if len(b_imgs) >= args.batch_size:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
b_imgs.clear()
|
||||
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs, images_metadata, args.tags_index)
|
||||
|
||||
if len(b_imgs) > 0:
|
||||
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs)
|
||||
if args.metadata is not None:
|
||||
logger.info(f"saving metadata file: {args.metadata}")
|
||||
with open(args.metadata, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
@@ -380,9 +536,7 @@ def main(args):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
@@ -400,9 +554,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
type=int,
|
||||
@@ -441,9 +593,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="debug mode"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
parser.add_argument(
|
||||
"--undesired_tags",
|
||||
type=str,
|
||||
@@ -453,20 +603,24 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
|
||||
)
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument(
|
||||
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
"--use_rating_tags",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
"--use_rating_tags_as_last_tag",
|
||||
action="store_true",
|
||||
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
"--character_tags_first",
|
||||
action="store_true",
|
||||
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--always_first_tags",
|
||||
@@ -495,6 +649,25 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tags_index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
|
||||
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_archive",
|
||||
action="store_true",
|
||||
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
|
||||
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_tolen_length,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user