From c316c63dffa1c2ea079d7556f691e2613c6b6f00 Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Wed, 12 Apr 2023 05:41:28 +0700 Subject: [PATCH] fix: bring positional args back, add recursive to blip etc --- finetune/make_captions.py | 7 +++++-- finetune/make_captions_by_git.py | 7 +++++-- finetune/prepare_buckets_latents.py | 29 +++++++++++++++++++-------- finetune/tag_images_by_wd14_tagger.py | 8 ++++---- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index e690349a..1c8df2f4 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -4,6 +4,7 @@ import os import json import random +from pathlib import Path from PIL import Image from tqdm import tqdm import numpy as np @@ -72,7 +73,8 @@ def main(args): os.chdir('finetune') print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") print(f"loading BLIP caption: {args.caption_weights}") @@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument("--debug", action="store_true", help="debug mode") - + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + return parser diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 06af5598..58fa4ccb 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -2,6 +2,7 @@ import argparse import os import re +from pathlib import Path from PIL import Image from tqdm import tqdm import torch @@ -65,7 +66,8 @@ def main(args): GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch print(f"load images from {args.train_data_dir}") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") # できればcacheに依存せず明示的にダウンロードしたい @@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--remove_words", action="store_true", help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") parser.add_argument("--debug", action="store_true", help="debug mode") - + parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively") + return parser diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 8d9a38ab..b9c0fa50 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -2,6 +2,8 @@ import argparse import os import json +from pathlib import Path +from typing import List from tqdm import tqdm import numpy as np from PIL import Image @@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype): return latents -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): if is_full_path: base_name = os.path.splitext(os.path.basename(image_key))[0] + relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) else: base_name = image_key + relative_path = "" + if flip: base_name += '_flip' - return os.path.join(data_dir, base_name) + + if recursive and relative_path: + return os.path.join(data_dir, relative_path, base_name) + else: + return os.path.join(data_dir, base_name) + def main(args): @@ -56,7 +66,8 @@ def main(args): if args.bucket_reso_steps % 8 > 0: print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") - image_paths = train_util.glob_images(args.train_data_dir) + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] print(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): @@ -99,7 +110,7 @@ def main(args): f"latent shape {latents.shape}, {bucket[0][1].shape}" for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) np.savez(npz_file_name, latent) # flip @@ -107,12 +118,12 @@ def main(args): latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない for (image_key, _), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) np.savez(npz_file_name, latent) else: # remove existing flipped npz for image_key, _ in bucket: - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" if os.path.isfile(npz_file_name): print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") os.remove(npz_file_name) @@ -169,9 +180,9 @@ def main(args): # 既に存在するファイルがあればshapeを確認して同じならskipする if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] if args.flip_aug: - npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") + npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz") found = True for npz_file in npz_files: @@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser: help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") parser.add_argument("--skip_existing", action="store_true", help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") return parser diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index ac00baec..efcf3fc1 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -10,7 +10,7 @@ import numpy as np from tensorflow.keras.models import load_model from huggingface_hub import hf_hub_download import torch -import pathlib +from pathlib import Path import library.train_util as train_util @@ -103,8 +103,8 @@ def main(args): # 画像を読み込む - train_data_dir = pathlib.Path(args.train_data_dir) - image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) + train_data_dir_path = Path(args.train_data_dir) + image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") tag_freq = {} @@ -205,7 +205,7 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",