Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 08:52:45 +00:00
feat: enhance image loading and processing in ImageLoadingPrepDataset with batch support and output options
@@ -1,13 +1,14 @@
import argparse
import csv
import json
import math
import os
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download, errors
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm

@@ -64,33 +65,40 @@ def preprocess_image(image: Image.Image) -> np.ndarray:


class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __init__(self, image_paths: list[str], batch_size: int):
self.image_paths = image_paths
self.batch_size = batch_size

def __len__(self):
return len(self.images)
return math.ceil(len(self.image_paths) / self.batch_size)

def __getitem__(self, idx):
img_path = str(self.images[idx])
def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]:
image_index_start = batch_index * self.batch_size
image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths))

try:
image = Image.open(img_path)
image = preprocess_image(image)
# tensor = torch.tensor(image)  # no need to convert this to a Tensor after all... (;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
batch_image_paths = []
images = []
image_sizes = []
for idx in range(image_index_start, image_index_end):
img_path = str(self.image_paths[idx])

return (image, img_path)
try:
image = Image.open(img_path)
image_size = image.size
image = preprocess_image(image)

batch_image_paths.append(img_path)
images.append(image)
image_sizes.append(image_size)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")

images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
return batch_image_paths, images, image_sizes
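
For orientation, a minimal standalone sketch of the batch indexing the reworked __getitem__ relies on; the paths and batch size below are illustrative, not taken from the script:

import math

paths = ["a.png", "b.png", "c.png", "d.png", "e.png"]
batch_size = 2

num_batches = math.ceil(len(paths) / batch_size)  # __len__: 3 batches (2 + 2 + 1)
for b in range(num_batches):
    start = b * batch_size
    end = min((b + 1) * batch_size, len(paths))
    print(paths[start:end])  # the last batch is simply shorter, no padding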


def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
def collate_fn_no_op(batch):
"""Collate function that does nothing and returns the batch as is."""
return batch
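
A minimal sketch of why the no-op collate is paired with batch_size=1 further below: each dataset item is already a complete batch, so the DataLoader only hands it through instead of stacking samples itself. The tiny dataset here is illustrative, not the script's own:

import torch

class PrebatchedDataset(torch.utils.data.Dataset):
    def __init__(self, batches):
        self.batches = batches  # each item is a full, pre-built batch
    def __len__(self):
        return len(self.batches)
    def __getitem__(self, i):
        return self.batches[i]

def collate_fn_no_op(batch):
    return batch  # keep the pre-built batch untouched

loader = torch.utils.data.DataLoader(
    PrebatchedDataset([["a.png", "b.png"], ["c.png"]]),
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_no_op,
)
for data_entry in loader:
    print(data_entry[0])  # ['a.png', 'b.png'], then ['c.png']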

@@ -311,6 +319,7 @@ def main(args):
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
image_paths = [str(ip) for ip in image_paths]

tag_freq = {}

@@ -323,8 +332,11 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]

def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[list[str]]:
nonlocal args, default_format, model, ort_sess, input_name, tag_freq

imgs = path_imgs[1]
result = {}

if args.onnx:
# if len(imgs) < args.batch_size:
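
For reference, a sketch of the batch tuple the reworked run_batch now receives, matching the type hint above; the concrete paths, sizes, and the IMAGE_SIZE value are illustrative assumptions:

import numpy as np

IMAGE_SIZE = 448  # assumed to match the script's constant
path_imgs = (
    ["img/0001.png", "img/0002.png"],                      # image paths in the batch
    np.zeros((2, IMAGE_SIZE, IMAGE_SIZE, 3), np.float32),  # stacked preprocessed images (NHWC)
    [(1024, 768), (640, 640)],                             # original (width, height) per image
)
imgs = path_imgs[1]  # what the inference code below operates on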

@@ -333,12 +345,12 @@ def main(args):
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW
imgs = imgs / 127.5 - 1.0
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
probs = probs[: len(imgs)] # remove padding
else:
probs = model(imgs, training=False)
probs = probs.numpy()

for (image_path, _), prob in zip(path_imgs, probs):
for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""

@@ -390,51 +402,64 @@ def main(args):
quality_max_prob = -1
quality_tag = None
character_tags = []
for i, p in enumerate(prob):
if i in tag_id_to_tag_mapping and p >= args.thresh:
tag_name = tag_id_to_tag_mapping[i]
category = tag_id_to_category_mapping[i]
if tag_name in undesired_tags:
continue

if category == "Rating":
if p > rating_max_prob:
rating_max_prob = p
rating_tag = tag_name
rating_tag_text = tag_name
continue
elif category == "Quality":
if p > quality_max_prob:
quality_max_prob = p
quality_tag = tag_name
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
other_tag_text += caption_separator + tag_name
continue
min_thres = min(
args.thresh,
args.general_threshold,
args.character_threshold,
args.copyright_threshold,
args.meta_threshold,
args.model_threshold,
)
prob_indices = np.where(prob >= min_thres)[0]
# for i, p in enumerate(prob):
for i in prob_indices:
if i not in tag_id_to_tag_mapping:
continue
p = prob[i]

if category == "General" and p >= args.general_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append((tag_name, p))
elif category == "Character" and p >= args.character_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # we separate character tags
character_tags.append((tag_name, p))
else:
combined_tags.append((tag_name, p))
elif (
(category == "Copyright" and p >= args.copyright_threshold)
or (category == "Meta" and p >= args.meta_threshold)
or (category == "Model" and p >= args.model_threshold)
):
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
other_tag_text += f"{caption_separator}{tag_name} ({category})"
tag_name = tag_id_to_tag_mapping[i]
category = tag_id_to_category_mapping[i]
if tag_name in undesired_tags:
continue

if category == "Rating":
if p > rating_max_prob:
rating_max_prob = p
rating_tag = tag_name
rating_tag_text = tag_name
continue
elif category == "Quality":
if p > quality_max_prob:
quality_max_prob = p
quality_tag = tag_name
if args.use_quality_tags or args.use_quality_tags_as_last_tag:
other_tag_text += caption_separator + tag_name
continue

if category == "General" and p >= args.general_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += caption_separator + tag_name
combined_tags.append((tag_name, p))
elif category == "Character" and p >= args.character_threshold:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # we separate character tags
character_tags.append((tag_name, p))
else:
combined_tags.append((tag_name, p))
elif (
(category == "Copyright" and p >= args.copyright_threshold)
or (category == "Meta" and p >= args.meta_threshold)
or (category == "Model" and p >= args.model_threshold)
):
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
other_tag_text += f"{caption_separator}{tag_name} ({category})"
combined_tags.append((tag_name, p))

# sort by probability
combined_tags.sort(key=lambda x: x[1], reverse=True)
if character_tags:
print(character_tags)
character_tags.sort(key=lambda x: x[1], reverse=True)
combined_tags = character_tags + combined_tags
combined_tags = [t[0] for t in combined_tags] # remove probability
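
To make the new pre-filtering step concrete, a small self-contained sketch; the probabilities and thresholds are made-up values, and only the indexing pattern mirrors the diff:

import numpy as np

prob = np.array([0.05, 0.92, 0.40, 0.88])
thresholds = {"general": 0.35, "character": 0.85}
min_thres = min(thresholds.values())

# visit only indices that can clear at least one category threshold,
# instead of looping over every tag id in the vocabulary
for i in np.where(prob >= min_thres)[0]:
    p = prob[i]
    print(int(i), float(p))  # 1 0.92, 2 0.4, 3 0.88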

@@ -486,55 +511,79 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)

with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if other_tag_text:
logger.info(f"\tOther tags: {other_tag_text}")
if not args.output_path:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
entry = {"tags": tag_text, "image_size": list(image_size)}
result[image_path] = entry

if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if other_tag_text:
logger.info(f"\tOther tags: {other_tag_text}")

return result
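
For reference, a sketch of the per-image entry run_batch accumulates when --output_path is set; the path, tags, and size below are invented examples, while the dict layout matches the code above:

result = {}
image_path = "img/0001.png"
tag_text = "1girl, solo, outdoors"
image_size = (1024, 768)  # PIL's (width, height)

result[image_path] = {"tags": tag_text, "image_size": list(image_size)}
# -> {"img/0001.png": {"tags": "1girl, solo, outdoors", "image_size": [1024, 768]}}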

# option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
dataset = ImageLoadingPrepDataset(image_paths, args.batch_size)
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
batch_size=1,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
collate_fn=collate_fn_no_op,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
# data = [[(ip, None, None)] for ip in image_paths]
data = [[]]
for ip in image_paths:
if len(data[-1]) >= args.batch_size:
data.append([])
data[-1].append((ip, None, None))
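
A standalone sketch of the fallback batching above (the path list and batch size are illustrative); it mirrors the manual chunking used when no DataLoader workers are requested:

image_paths = ["a.png", "b.png", "c.png", "d.png", "e.png"]
batch_size = 2

data = [[]]
for ip in image_paths:
    if len(data[-1]) >= batch_size:
        data.append([])
    data[-1].append((ip, None, None))  # image and size are filled in later

print(data)  # [[('a.png', None, None), ('b.png', None, None)], [('c.png', ...), ('d.png', ...)], [('e.png', ...)]]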

b_imgs = []
results = {}
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
if data_entry is None or len(data_entry) == 0:
continue

image, image_path = data
if image is None:
try:
image = Image.open(image_path)
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if data_entry[0][1] is None:
# No preloaded image, need to load
images = []
image_sizes = []
for image_path, _, _ in data_entry:
image = Image.open(image_path)
image_size = image.size
image = preprocess_image(image)
images.append(image)
image_sizes.append(image_size)
b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes)
else:
b_imgs = data_entry[0]

if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
r = run_batch(b_imgs)
if args.output_path and r is not None:
results.update(r)

if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.output_path:
if args.output_path.endswith(".jsonl"):
# optional JSONL metadata
with open(args.output_path, "wt", encoding="utf-8") as f:
for image_path, entry in results.items():
f.write(
json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n"
)
else:
# standard JSON metadata
with open(args.output_path, "wt", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
logger.info(f"captions saved to {args.output_path}")

if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -572,6 +621,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
)
parser.add_argument(
"--output_path",
type=str,
default=None,
help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます",
)
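
A short sketch of how the new flag switches the write path, assuming an argparse namespace like the one built by this parser (the values and the invocation in the comment are placeholders):

import argparse

# e.g. invoked as: python <tagger script> --train_data_dir ./imgs --batch_size 4 --output_path ./captions.json
args = argparse.Namespace(output_path="./captions.json")

if not args.output_path:
    print("default behaviour: write one .txt caption file per image")
elif args.output_path.endswith(".jsonl"):
    print("write one JSON object per line to", args.output_path)
else:
    print("write a single indented JSON document to", args.output_path)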
parser.add_argument(
"--caption_extention",
type=str,