diff --git a/finetune/make_captions.py b/finetune/make_captions.py index e690349a..1c8df2f4 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -4,6 +4,7 @@ import os import json import random +from pathlib import Path from PIL import Image from tqdm import tqdm import numpy as np @@ -72,7 +73,8 @@ def main(args): os.chdir('finetune') print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") print(f"loading BLIP caption: {args.caption_weights}") @@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument("--debug", action="store_true", help="debug mode") - + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + return parser diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 06af5598..58fa4ccb 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -2,6 +2,7 @@ import argparse import os import re +from pathlib import Path from PIL import Image from tqdm import tqdm import torch @@ -65,7 +66,8 @@ def main(args): GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい @@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--remove_words", action="store_true", help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") parser.add_argument("--debug", action="store_true", help="debug mode") - + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + return parser diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 8d9a38ab..b9c0fa50 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -2,6 +2,8 @@ import argparse import os import json +from pathlib import Path +from typing import List from tqdm import tqdm import numpy as np from PIL import Image @@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype): return latents -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): if is_full_path: base_name = os.path.splitext(os.path.basename(image_key))[0] + relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) else: base_name = image_key + relative_path = "" + if flip: base_name += '_flip' - return os.path.join(data_dir, base_name) + + if recursive and relative_path: + return os.path.join(data_dir, relative_path, base_name) + else: + return os.path.join(data_dir, base_name) + def main(args): @@ -56,7 +66,8 @@ def main(args): if args.bucket_reso_steps % 8 > 0: print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] print(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): @@ -99,7 +110,7 @@ def main(args): f"latent shape {latents.shape}, {bucket[0][1].shape}" for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) np.savez(npz_file_name, latent) # flip @@ -107,12 +118,12 @@ def main(args): latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) np.savez(npz_file_name, latent) else: # remove existing flipped npz for image_key, _ in bucket: - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" if os.path.isfile(npz_file_name): print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") os.remove(npz_file_name) @@ -169,9 +180,9 @@ def main(args): # 既に存在するファイルがあればshapeを確認して同じならskipする if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] if args.flip_aug: - npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") + npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") found = True for npz_file in npz_files: @@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser: help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") parser.add_argument("--skip_existing", action="store_true", help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") return parser diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 2286115e..efcf3fc1 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -10,6 +10,7 @@ import numpy as np from tensorflow.keras.models import load_model from huggingface_hub import hf_hub_download import torch +from pathlib import Path import library.train_util as train_util @@ -23,184 +24,212 @@ SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] - def preprocess_image(image): - image = np.array(image) - image = image[:, :, ::-1] # RGB->BGR + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR - # pad to square - size = max(image.shape[0:2]) - pad_x = size - image.shape[1] - pad_y = size - image.shape[0] - pad_l = pad_x // 2 - pad_t = pad_y // 2 - image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) - - image = image.astype(np.float32) - return image + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + image = image.astype(np.float32) + return image class ImageLoadingPrepDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] - - try: - image = Image.open(img_path).convert("RGB") - image = preprocess_image(image) - tensor = torch.tensor(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (tensor, img_path) + def __getitem__(self, idx): + img_path = str(self.images[idx]) + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + return (tensor, img_path) + def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. - """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) - return batch - + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch def main(args): - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 - # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 - if not os.path.exists(args.model_dir) or args.force_download: - print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") - for file in FILES: - hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) - for file in SUB_DIR_FILES: - hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( - args.model_dir, SUB_DIR), force_download=True, force_filename=file) - else: - print("using existing wd14 tagger model") + # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする + # depreacatedの警告が出るけどなくなったらその時 + # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(args.model_dir) or args.force_download: + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + for file in FILES: + hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) + for file in SUB_DIR_FILES: + hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( + args.model_dir, SUB_DIR), force_download=True, force_filename=file) + else: + print("using existing wd14 tagger model") - # 画像を読み込む - image_paths = train_util.glob_images(args.train_data_dir) - print(f"found {len(image_paths)} images.") + # 画像を読み込む + model = load_model(args.model_dir) - print("loading model and labels") - model = load_model(args.model_dir) + # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") + # 依存ライブラリを増やしたくないので自力で読むよ - # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] - assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" + with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + l = [row for row in reader] + header = l[0] # tag_id,name,category,count + rows = l[1:] + assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" - tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ + general_tags = [row[1] for row in rows[1:] if row[2] == '0'] + character_tags = [row[1] for row in rows[1:] if row[2] == '4'] - # 推論する - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) + # 画像を読み込む + + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) + print(f"found {len(image_paths)} images.") - probs = model(imgs, training=False) - probs = probs.numpy() + tag_freq = {} - for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] + undesired_tags = set(args.undesired_tags.split(',')) - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold - tag_text = "" - for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで - if p >= args.thresh and i < len(tags): - tag_text += ", " + tags[i] + def run_batch(path_imgs): + imgs = np.array([im for _, im in path_imgs]) - if len(tag_text) > 0: - tag_text = tag_text[2:] # 最初の ", " を消す + probs = model(imgs, training=False) + probs = probs.numpy() - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(tag_text + '\n') - if args.debug: - print(image_path, tag_text) + for (image_path, _), prob in zip(path_imgs, probs): + # 最初の4つはratingなので無視する + # # First 4 labels are actually ratings: pick one with argmax + # ratings_names = label_names[:4] + # rating_index = ratings_names["probs"].argmax() + # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - # 読み込みの高速化のためにDataLoaderを使うオプション - if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) - data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) - else: - data = [[(None, ip)] for ip in image_paths] + # それ以降はタグなのでconfidenceがthresholdより高いものを追加する + # Everything else is tags: pick any where prediction confidence > threshold + combined_tags = [] + general_tag_text = "" + character_tag_text = "" + for i, p in enumerate(prob[4:]): + if i < len(general_tags) and p >= args.general_threshold: + tag_name = general_tags[i].replace('_', ' ') if args.remove_underscore else general_tags[i] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += ", " + tag_name + combined_tags.append(tag_name) + elif i >= len(general_tags) and p >= args.character_threshold: + tag_name = character_tags[i - len(general_tags)].replace('_', ' ') if args.remove_underscore else character_tags[i - len(general_tags)] + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += ", " + tag_name + combined_tags.append(tag_name) - b_imgs = [] - for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue + if len(general_tag_text) > 0: + general_tag_text = general_tag_text[2:] - image, image_path = data - if image is not None: - image = image.detach().numpy() - else: - try: - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - image = preprocess_image(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) + if len(character_tag_text) > 0: + character_tag_text = character_tag_text[2:] - if len(b_imgs) >= args.batch_size: + tag_text = ', '.join(combined_tags) + + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(tag_text + '\n') + if args.debug: + print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) + + if len(b_imgs) >= args.batch_size: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) + if args.frequency_tags: + sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) + print("\nTag frequencies:") + for tag, freq in sorted_tags: + print(f"{tag}: {freq}") - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, - help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") - parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", - help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") - parser.add_argument("--force_download", action='store_true', - help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--max_data_loader_n_workers", type=int, default=None, - help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode") - - return parser + print("done!") if __name__ == '__main__': - parser = setup_parser() + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, + help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") + parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", + help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") + parser.add_argument("--force_download", action='store_true', + help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") + parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument("--general_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for general category") + parser.add_argument("--character_threshold", type=float, default=0.35, help="threshold of confidence to add a tag for character category") + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + parser.add_argument("--remove_underscore", action="store_true", help="replace underscores with spaces in the output tags") + parser.add_argument("--debug", action="store_true", help="debug mode") + parser.add_argument("--undesired_tags", type=str, default="", help="comma-separated list of undesired tags to remove from the output") + parser.add_argument('--frequency_tags', action='store_true', help='Show frequency of tags for images') - args = parser.parse_args() + args = parser.parse_args() - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention + # スペルミスしていたオプションを復元する + if args.caption_extention is not None: + args.caption_extension = args.caption_extention - main(args) + main(args) diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index 7c7cc1c5..c33ca202 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -12,12 +12,12 @@ def convert(args): # 引数を確認する load_dtype = torch.float16 if args.fp16 else None - save_dtype = None - if args.fp16: + save_dtype = None + if args.fp16 or args.save_precision_as == "fp16": save_dtype = torch.float16 - elif args.bf16: + elif args.bf16 or args.save_precision_as == "bf16": save_dtype = torch.bfloat16 - elif args.float: + elif args.float or args.save_precision_as == "float": save_dtype = torch.float is_load_ckpt = os.path.isfile(args.model_to_load) @@ -72,6 +72,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') parser.add_argument("--float", action='store_true', help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') + parser.add_argument("--save_precision_as", type=str, default="no", choices=["fp16", "bf16", "float"], + help="save precision") parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') parser.add_argument("--global_step", type=int, default=0, help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')