class ImageLoadingPrepDataset(torch.utils.data.Dataset):
    """Dataset whose items are whole, already-preprocessed image batches.

    Item ``i`` covers ``image_paths[i * batch_size:(i + 1) * batch_size]``.
    Batching happens here rather than in the ``DataLoader`` so that each
    worker call can drop unreadable images and still hand back the paths,
    pixel data and original sizes of the images that did load.

    Args:
        image_paths: Paths of the images to load, in processing order.
        batch_size: Maximum number of images per item; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, image_paths: list[str], batch_size: int):
        # Fail early with a clear message instead of a ZeroDivisionError
        # (or a meaningless length) later in __len__.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.image_paths = image_paths
        self.batch_size = batch_size

    def __len__(self) -> int:
        """Return the number of batches; the last one may be partial."""
        return math.ceil(len(self.image_paths) / self.batch_size)

    def __getitem__(self, batch_index: int) -> tuple[list[str], np.ndarray, list[tuple[int, int]]]:
        """Load and preprocess every readable image of one batch.

        Args:
            batch_index: Index of the batch, in ``range(len(self))``.

        Returns:
            A ``(paths, images, sizes)`` triple with one entry per image that
            loaded successfully: the path as ``str``, the stacked output of
            ``preprocess_image``, and ``PIL.Image.size`` read before
            preprocessing. Images that fail to load are logged and skipped,
            so the batch can be shorter than ``batch_size`` or empty.
        """
        image_index_start = batch_index * self.batch_size
        image_index_end = min(image_index_start + self.batch_size, len(self.image_paths))

        batch_image_paths = []
        images = []
        image_sizes = []
        for idx in range(image_index_start, image_index_end):
            img_path = str(self.image_paths[idx])

            try:
                # Close the file handle as soon as the pixels are extracted;
                # preprocess_image returns a NumPy array, so nothing refers
                # to the PIL image once the block exits.
                with Image.open(img_path) as image:
                    image_size = image.size
                    preprocessed = preprocess_image(image)
            except Exception as e:
                # Best effort: one unreadable file must not abort the run.
                logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
                continue

            batch_image_paths.append(img_path)
            images.append(preprocessed)
            image_sizes.append(image_size)

        # np.stack rejects an empty list, so an all-failed batch becomes an
        # explicit zero-length array.
        if images:
            stacked = np.stack(images)
        else:
            stacked = np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3))
        return batch_image_paths, stacked, image_sizes


def collate_fn_no_op(batch):
    """Return ``batch`` unchanged.

    Passed as ``collate_fn`` so the ``DataLoader`` does not try to merge the
    ``(paths, images, sizes)`` triples that the dataset has already batched.
    """
    return batch
character_tags = [] - for i, p in enumerate(prob): - if i in tag_id_to_tag_mapping and p >= args.thresh: - tag_name = tag_id_to_tag_mapping[i] - category = tag_id_to_category_mapping[i] - if tag_name in undesired_tags: - continue - if category == "Rating": - if p > rating_max_prob: - rating_max_prob = p - rating_tag = tag_name - rating_tag_text = tag_name - continue - elif category == "Quality": - if p > quality_max_prob: - quality_max_prob = p - quality_tag = tag_name - if args.use_quality_tags or args.use_quality_tags_as_last_tag: - other_tag_text += caption_separator + tag_name - continue + min_thres = min( + args.thresh, + args.general_threshold, + args.character_threshold, + args.copyright_threshold, + args.meta_threshold, + args.model_threshold, + ) + prob_indices = np.where(prob >= min_thres)[0] + # for i, p in enumerate(prob): + for i in prob_indices: + if i not in tag_id_to_tag_mapping: + continue + p = prob[i] - if category == "General" and p >= args.general_threshold: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += caption_separator + tag_name - combined_tags.append((tag_name, p)) - elif category == "Character" and p >= args.character_threshold: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += caption_separator + tag_name - if args.character_tags_first: # we separate character tags - character_tags.append((tag_name, p)) - else: - combined_tags.append((tag_name, p)) - elif ( - (category == "Copyright" and p >= args.copyright_threshold) - or (category == "Meta" and p >= args.meta_threshold) - or (category == "Model" and p >= args.model_threshold) - ): - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - other_tag_text += f"{caption_separator}{tag_name} ({category})" + tag_name = tag_id_to_tag_mapping[i] + category = tag_id_to_category_mapping[i] + if tag_name in undesired_tags: + continue + + if category == "Rating": + if p > rating_max_prob: + rating_max_prob = p + rating_tag = tag_name + 
rating_tag_text = tag_name + continue + elif category == "Quality": + if p > quality_max_prob: + quality_max_prob = p + quality_tag = tag_name + if args.use_quality_tags or args.use_quality_tags_as_last_tag: + other_tag_text += caption_separator + tag_name + continue + + if category == "General" and p >= args.general_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += caption_separator + tag_name + combined_tags.append((tag_name, p)) + elif category == "Character" and p >= args.character_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # we separate character tags + character_tags.append((tag_name, p)) + else: combined_tags.append((tag_name, p)) + elif ( + (category == "Copyright" and p >= args.copyright_threshold) + or (category == "Meta" and p >= args.meta_threshold) + or (category == "Model" and p >= args.model_threshold) + ): + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + other_tag_text += f"{caption_separator}{tag_name} ({category})" + combined_tags.append((tag_name, p)) # sort by probability combined_tags.sort(key=lambda x: x[1], reverse=True) if character_tags: - print(character_tags) character_tags.sort(key=lambda x: x[1], reverse=True) combined_tags = character_tags + combined_tags combined_tags = [t[0] for t in combined_tags] # remove probability @@ -486,55 +511,79 @@ def main(args): # Create new tag_text tag_text = caption_separator.join(existing_tags + new_tags) - with open(caption_file, "wt", encoding="utf-8") as f: - f.write(tag_text + "\n") - if args.debug: - logger.info("") - logger.info(f"{image_path}:") - logger.info(f"\tRating tags: {rating_tag_text}") - logger.info(f"\tCharacter tags: {character_tag_text}") - logger.info(f"\tGeneral tags: {general_tag_text}") - if other_tag_text: - logger.info(f"\tOther tags: {other_tag_text}") + if not args.output_path: + with open(caption_file, "wt", encoding="utf-8") as 
f: + f.write(tag_text + "\n") + else: + entry = {"tags": tag_text, "image_size": list(image_size)} + result[image_path] = entry + + if args.debug: + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tRating tags: {rating_tag_text}") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") + if other_tag_text: + logger.info(f"\tOther tags: {other_tag_text}") + + return result # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) + dataset = ImageLoadingPrepDataset(image_paths, args.batch_size) data = torch.utils.data.DataLoader( dataset, - batch_size=args.batch_size, + batch_size=1, shuffle=False, num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, + collate_fn=collate_fn_no_op, drop_last=False, ) else: - data = [[(None, ip)] for ip in image_paths] + # data = [[(ip, None, None)] for ip in image_paths] + data = [[]] + for ip in image_paths: + if len(data[-1]) >= args.batch_size: + data.append([]) + data[-1].append((ip, None, None)) - b_imgs = [] + results = {} for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue + if data_entry is None or len(data_entry) == 0: + continue - image, image_path = data - if image is None: - try: - image = Image.open(image_path) - image = preprocess_image(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) + if data_entry[0][1] is None: + # No preloaded image, need to load + images = [] + image_sizes = [] + for image_path, _, _ in data_entry: + image = Image.open(image_path) + image_size = image.size + image = preprocess_image(image) + images.append(image) + image_sizes.append(image_size) + b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes) + else: + b_imgs = data_entry[0] - if 
len(b_imgs) >= args.batch_size: - b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) - b_imgs.clear() + r = run_batch(b_imgs) + if args.output_path and r is not None: + results.update(r) - if len(b_imgs) > 0: - b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) + if args.output_path: + if args.output_path.endswith(".jsonl"): + # optional JSONL metadata + with open(args.output_path, "wt", encoding="utf-8") as f: + for image_path, entry in results.items(): + f.write( + json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n" + ) + else: + # standard JSON metadata + with open(args.output_path, "wt", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=4) + logger.info(f"captions saved to {args.output_path}") if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) @@ -572,6 +621,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--output_path", + type=str, + default=None, + help="path for output captions (json format). if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます", + ) parser.add_argument( "--caption_extention", type=str,