From 4735b213183fd3a101dee310f1c7376e9746a411 Mon Sep 17 00:00:00 2001 From: breakcore2 Date: Fri, 6 Jan 2023 04:07:04 -0800 Subject: [PATCH 01/14] add .bmp support for wd14 tagger --- finetune/tag_images_by_wd14_tagger.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index c5767894..c543332e 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -37,7 +37,9 @@ def main(args): # 画像を読み込む image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + \ + glob.glob(os.path.join(args.train_data_dir, "*.bmp")) print(f"found {len(image_paths)} images.") print("loading model and labels") From 2e8a3d20dd672d1036dc6bd8a059d480818cf4ae Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 23 Jan 2023 17:43:03 -0800 Subject: [PATCH 02/14] Add tag frequency metadata --- library/train_util.py | 10 ++++++++++ train_network.py | 1 + 2 files changed, 11 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 0fdbadc1..ba7a5739 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -85,6 +85,7 @@ class BaseDataset(torch.utils.data.Dataset): self.enable_bucket = False self.min_bucket_reso = None self.max_bucket_reso = None + self.tag_frequency = {} self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -520,6 +521,15 @@ class DreamBoothDataset(BaseDataset): cap_for_img = read_caption(img_path) captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {}) + self.tag_frequency[os.path.basename(dir)] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + if tag and not tag.isspace(): + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + return n_repeats, img_paths, captions print("prepare train images.") diff --git a/train_network.py b/train_network.py index d60ae9a0..f3ff3c51 100644 --- a/train_network.py +++ b/train_network.py @@ -264,6 +264,7 @@ def train(args): "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(train_dataset.tag_frequency), "ss_training_comment": args.training_comment # will not be updated after training } From 2ce9ad235cdb3757b517f486c63bc31fee7bc30f Mon Sep 17 00:00:00 2001 From: breakcore2 Date: Thu, 26 Jan 2023 01:01:38 -0800 Subject: [PATCH 03/14] add recursive structure merge dd tags and convert to pathlib --- finetune/merge_dd_tags_to_metadata.py | 40 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 8101ecd3..913fa90a 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -2,25 +2,33 @@ # (c) 2022 Kohya S. @kohya_ss import argparse -import glob -import os import json +from pathlib import Path from tqdm import tqdm def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + + image_paths = None + train_data_dir_path = Path(args.train_data_dir) + if args.recursive_data_dir: + image_paths = list(train_data_dir_path.rglob('*.jpg')) + \ + list(train_data_dir_path.rglob('*.png')) + \ + list(train_data_dir_path.rglob('*.webp')) + else: + image_paths = list(train_data_dir_path.glob('*.jpg')) + \ + list(train_data_dir_path.glob('*.png')) + \ + list(train_data_dir_path.glob('*.webp')) + print(f"found {len(image_paths)} images.") - if args.in_json is None and os.path.isfile(args.out_json): + if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) + metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") else: print("new metadata will be created / 新しいメタデータファイルが作成されます") @@ -28,22 +36,21 @@ def main(args): print("merge tags to metadata json.") for image_path in tqdm(image_paths): - tags_path = os.path.splitext(image_path)[0] + '.txt' - with open(tags_path, "rt", encoding='utf-8') as f: - tags = f.readlines()[0].strip() + tags_path = image_path.with_suffix('.txt') + tags = tags_path.read_text(encoding='utf-8').strip() - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] - if image_key not in metadata: - metadata[image_key] = {} + image_key = image_path if args.full_path else image_path.stem + if str(image_key) not in metadata: + metadata[str(image_key)] = {} - metadata[image_key]['tags'] = tags + metadata[str(image_key)]['tags'] = tags if args.debug: print(image_key, tags) # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding='utf-8') as f: - json.dump(metadata, f, indent=2) + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') + print("done!") @@ -54,6 +61,7 @@ if __name__ == '__main__': parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") parser.add_argument("--full_path", action="store_true", help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") + parser.add_argument("--recursive_data_dir", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") parser.add_argument("--debug", action="store_true", help="debug mode, print tags") args = parser.parse_args() From 64d5ceda71326fd95928fc848fd39f6d18061be7 Mon Sep 17 00:00:00 2001 From: breakcore2 Date: Thu, 26 Jan 2023 01:06:33 -0800 Subject: [PATCH 04/14] simplify arg to --recursive --- finetune/merge_dd_tags_to_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 913fa90a..849d37d9 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -12,7 +12,7 @@ def main(args): image_paths = None train_data_dir_path = Path(args.train_data_dir) - if args.recursive_data_dir: + if args.recursive: image_paths = list(train_data_dir_path.rglob('*.jpg')) + \ list(train_data_dir_path.rglob('*.png')) + \ list(train_data_dir_path.rglob('*.webp')) @@ -61,7 +61,7 @@ if __name__ == '__main__': parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") parser.add_argument("--full_path", action="store_true", help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive_data_dir", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") + parser.add_argument("--recursive", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") parser.add_argument("--debug", action="store_true", help="debug mode, print tags") args = parser.parse_args() From c20745b6e8344b0db73536c4f9b30382eb0fc263 Mon Sep 17 00:00:00 2001 From: unknown Date: Sun, 29 Jan 2023 22:30:45 +0900 Subject: [PATCH 05/14] fix: #53 --- finetune/make_captions.py | 2 +- finetune/merge_captions_to_metadata.py | 5 +++-- finetune/merge_dd_tags_to_metadata.py | 2 +- finetune/prepare_buckets_latents.py | 2 +- finetune/tag_images_by_wd14_tagger.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index b02420bd..495450aa 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -31,7 +31,7 @@ def main(args): os.chdir('finetune') print(f"load images from {args.train_data_dir}") - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 2da6356f..703f4f9d 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -10,7 +10,7 @@ from tqdm import tqdm def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") @@ -30,7 +30,8 @@ def main(args): for image_path in tqdm(image_paths): caption_path = os.path.splitext(image_path)[0] + args.caption_extension with open(caption_path, "rt", encoding='utf-8') as f: - caption = f.readlines()[0].strip() + lines = f.readlines() + caption = lines[0].strip() if len(lines) > 0 else "" image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] if image_key not in metadata: diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 8101ecd3..896c99ae 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -10,7 +10,7 @@ from tqdm import tqdm def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 00f847a1..87236c43 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype): def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index c5767894..60a9a890 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -36,7 +36,7 @@ def main(args): args.model_dir, SUB_DIR), force_download=True, force_filename=file) # 画像を読み込む - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ + image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") From 57d8483eaf5be6081c1f167f572291f449c6ba0c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 08:45:33 +0900 Subject: [PATCH 06/14] add GIT captioning, refactoring, DataLoader --- .gitignore | 3 +- finetune/make_captions.py | 101 +++++++++++++----- finetune/make_captions_by_git.py | 136 +++++++++++++++++++++++++ finetune/merge_captions_to_metadata.py | 36 ++++--- finetune/merge_dd_tags_to_metadata.py | 41 ++++---- finetune/prepare_buckets_latents.py | 125 +++++++++++++++++------ finetune/tag_images_by_wd14_tagger.py | 128 ++++++++++++++++------- library/train_util.py | 51 ++++++++-- requirements.txt | 2 +- 9 files changed, 479 insertions(+), 144 deletions(-) create mode 100644 finetune/make_captions_by_git.py diff --git a/.gitignore b/.gitignore index 7c088d5c..0904a2a4 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ __pycache__ wd14_tagger_model venv *.egg-info -build \ No newline at end of file +build +.vscode \ No newline at end of file diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 495450aa..a2a35b39 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -11,18 +11,59 @@ import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from blip.blip import blip_decoder -# from Salesforce_BLIP.models.blip import blip_decoder +import library.train_util as train_util DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +IMAGE_SIZE = 384 + +# 正方形でいいのか? という気がするがソースがそうなので +IMAGE_TRANSFORM = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) +]) + +# 共通化したいが微妙に処理が異なる…… +class ImageLoadingTransformDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor = IMAGE_TRANSFORM(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + + return (tensor, img_path) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + def main(args): # fix the seed for reproducibility - seed = args.seed # + utils.get_rank() + seed = args.seed # + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) - + if not os.path.exists("blip"): args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path @@ -31,24 +72,15 @@ def main(args): os.chdir('finetune') print(f"load images from {args.train_data_dir}") - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + image_paths = train_util.glob_images(args.train_data_dir) print(f"found {len(image_paths)} images.") print(f"loading BLIP caption: {args.caption_weights}") - image_size = 384 - model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json") + model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) print("BLIP loaded") - # 正方形でいいのか? という気がするがソースがそうなので - transform = transforms.Compose([ - transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ]) - # captioningする def run_batch(path_imgs): imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) @@ -66,18 +98,35 @@ def main(args): if args.debug: print(image_path, caption) - b_imgs = [] - for image_path in tqdm(image_paths, smoothing=0.0): - raw_image = Image.open(image_path) - if raw_image.mode != "RGB": - print(f"convert image mode {raw_image.mode} to RGB: {image_path}") - raw_image = raw_image.convert("RGB") + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingTransformDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] - image = transform(raw_image) - b_imgs.append((image_path, image)) - if len(b_imgs) >= args.batch_size: - run_batch(b_imgs) - b_imgs.clear() + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + img_tensor, image_path = data + if img_tensor is None: + try: + raw_image = Image.open(image_path) + if raw_image.mode != 'RGB': + raw_image = raw_image.convert("RGB") + img_tensor = IMAGE_TRANSFORM(raw_image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, img_tensor)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() if len(b_imgs) > 0: run_batch(b_imgs) @@ -95,6 +144,8 @@ if __name__ == '__main__': parser.add_argument("--beam_search", action="store_true", help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py new file mode 100644 index 00000000..daaef9e7 --- /dev/null +++ b/finetune/make_captions_by_git.py @@ -0,0 +1,136 @@ +import argparse +import os +import re + +from PIL import Image +from tqdm import tqdm +import torch +from transformers import AutoProcessor, AutoModelForCausalLM +from transformers.generation.utils import GenerationMixin + +import library.train_util as train_util + + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'), + re.compile(r'that says (" ?[^"]*"|\w+)')] + + +# 誤検知しまくりの with the word xxxx を消す +def remove_words(captions, debug): + removed_caps = [] + for caption in captions: + cap = caption + for pat in PATTERN_REPLACE: + cap = pat.sub("", caption) + if debug and cap != caption: + print(caption) + print(cap) + removed_caps.append(cap) + return removed_caps + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + +def main(args): + # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 + org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation + curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように + + # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す + # ここより上で置き換えようとするとすごく大変 + def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): + input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) + if input_ids.size()[0] != curr_batch_size[0]: + input_ids = input_ids.repeat(curr_batch_size[0], 1) + return input_ids + GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + + print(f"load images from {args.train_data_dir}") + image_paths = train_util.glob_images(args.train_data_dir) + print(f"found {len(image_paths)} images.") + + # できればcacheに依存せず明示的にダウンロードしたい + print(f"loading GIT: {args.model_id}") + git_processor = AutoProcessor.from_pretrained(args.model_id) + git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) + print("GIT loaded") + + # captioningする + def run_batch(path_imgs): + imgs = [im for _, im in path_imgs] + + curr_batch_size[0] = len(path_imgs) + inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 + generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) + captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) + + if args.remove_words: + captions = remove_words(captions, args.debug) + + for (image_path, _), caption in zip(path_imgs, captions): + with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: + f.write(caption + "\n") + if args.debug: + print(image_path, caption) + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + b_imgs = [] + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue + + image, image_path = data + if image is None: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + + b_imgs.append((image_path, image)) + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() + + if len(b_imgs) > 0: + run_batch(b_imgs) + + print("done!") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") + parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps", + help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID") + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") + parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") + parser.add_argument("--remove_words", action="store_true", + help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") + parser.add_argument("--debug", action="store_true", help="debug mode") + + args = parser.parse_args() + main(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index 703f4f9d..cbc5033f 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -1,26 +1,24 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - import argparse -import glob -import os import json - +from pathlib import Path +from typing import List from tqdm import tqdm +import library.train_util as train_util def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") - if args.in_json is None and os.path.isfile(args.out_json): + if args.in_json is None and Path(args.out_json).is_file(): args.in_json = args.out_json if args.in_json is not None: print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) + metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") else: print("new metadata will be created / 新しいメタデータファイルが作成されます") @@ -28,12 +26,10 @@ def main(args): print("merge caption texts to metadata json.") for image_path in tqdm(image_paths): - caption_path = os.path.splitext(image_path)[0] + args.caption_extension - with open(caption_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - caption = lines[0].strip() if len(lines) > 0 else "" + caption_path = image_path.with_suffix(args.caption_extension) + caption = caption_path.read_text(encoding='utf-8').strip() - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} @@ -43,8 +39,7 @@ def main(args): # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") - with open(args.out_json, "wt", encoding='utf-8') as f: - json.dump(metadata, f, indent=2) + Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') print("done!") @@ -52,12 +47,15 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") + parser.add_argument("--in_json", type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") parser.add_argument("--caption_extention", type=str, default=None, help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") parser.add_argument("--full_path", action="store_true", help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") parser.add_argument("--debug", action="store_true", help="debug mode") args = parser.parse_args() diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 16267cd8..4285feb0 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -1,27 +1,16 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - import argparse import json from pathlib import Path - +from typing import List from tqdm import tqdm +import library.train_util as train_util def main(args): - image_paths = None - train_data_dir_path = Path(args.train_data_dir) - if args.recursive: - image_paths = list(train_data_dir_path.rglob('*.jpg')) + \ - list(train_data_dir_path.rglob('*.jpeg')) + \ - list(train_data_dir_path.rglob('*.png')) + \ - list(train_data_dir_path.rglob('*.webp')) - else: - image_paths = list(train_data_dir_path.glob('*.jpg')) + \ - list(train_data_dir_path.glob('*.jpeg')) + \ - list(train_data_dir_path.glob('*.png')) + \ - list(train_data_dir_path.glob('*.webp')) + assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" + train_data_dir_path = Path(args.train_data_dir) + image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) print(f"found {len(image_paths)} images.") if args.in_json is None and Path(args.out_json).is_file(): @@ -37,21 +26,21 @@ def main(args): print("merge tags to metadata json.") for image_path in tqdm(image_paths): - tags_path = image_path.with_suffix('.txt') + tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() - image_key = image_path if args.full_path else image_path.stem - if str(image_key) not in metadata: - metadata[str(image_key)] = {} + image_key = str(image_path) if args.full_path else image_path.stem + if image_key not in metadata: + metadata[image_key] = {} - metadata[str(image_key)]['tags'] = tags + metadata[image_key]['tags'] = tags if args.debug: print(image_key, tags) # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') - + print("done!") @@ -59,10 +48,14 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") - parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") + parser.add_argument("--in_json", type=str, + help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") parser.add_argument("--full_path", action="store_true", help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") - parser.add_argument("--recursive", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") + parser.add_argument("--recursive", action="store_true", + help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") + parser.add_argument("--caption_extension", type=str, default=".txt", + help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") parser.add_argument("--debug", action="store_true", help="debug mode, print tags") args = parser.parse_args() diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 87236c43..537626d8 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -1,20 +1,16 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - import argparse -import glob import os import json from tqdm import tqdm import numpy as np -from diffusers import AutoencoderKL from PIL import Image import cv2 import torch from torchvision import transforms import library.model_util as model_util +import library.train_util as train_util DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose( ) +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + def get_latents(vae, images, weight_dtype): img_tensors = [IMAGE_TRANSFORMS(image) for image in images] img_tensors = torch.stack(img_tensors) @@ -35,9 +41,18 @@ def get_latents(vae, images, weight_dtype): return latents +def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): + if is_full_path: + base_name = os.path.splitext(os.path.basename(image_key))[0] + else: + base_name = image_key + if flip: + base_name += '_flip' + return os.path.join(data_dir, base_name) + + def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) + image_paths = train_util.glob_images(args.train_data_dir) print(f"found {len(image_paths)} images.") if os.path.exists(args.in_json): @@ -48,6 +63,25 @@ def main(args): print(f"no metadata / メタデータファイルがありません: {args.in_json}") return + # 既に存在するファイルをfilterする + if args.skip_existing: + filtered = [] + for image_path in image_paths: + image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] + + npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz" + if os.path.exists(npz_file_name_flip): + if not args.flip_aug: + continue + + npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" + if os.path.exists(npz_file_name_flip): + continue + + filtered.apppend(image_path) + print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}") + image_paths = filtered + weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -70,15 +104,55 @@ def main(args): buckets_imgs = [[] for _ in range(len(bucket_resos))] bucket_counts = [0 for _ in range(len(bucket_resos))] img_ar_errors = [] - for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)): + + def process_batch(is_last): + for j in range(len(buckets_imgs)): + bucket = buckets_imgs[j] + if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: + latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) + + for (image_key, _, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + np.savez(npz_file_name, latent) + + # flip + if args.flip_aug: + latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない + + for (image_key, _, _), latent in zip(bucket, latents): + npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + np.savez(npz_file_name, latent) + + bucket.clear() + + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = train_util.ImageLoadingDataset(image_paths) + data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + + for data_entry in tqdm(data, smoothing=0.0): + if data_entry[0] is None: + continue + + img_tensor, image_path = data_entry[0] + if img_tensor is not None: + image = transforms.functional.to_pil_image(img_tensor) + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] if image_key not in metadata: metadata[image_key] = {} - image = Image.open(image_path) - if image.mode != 'RGB': - image = image.convert("RGB") - aspect_ratio = image.width / image.height ar_errors = bucket_aspect_ratios - aspect_ratio bucket_id = np.abs(ar_errors).argmin() @@ -123,25 +197,10 @@ def main(args): metadata[image_key]['train_resolution'] = reso # バッチを推論するか判定して推論する - is_last = i == len(image_paths) - 1 - for j in range(len(buckets_imgs)): - bucket = buckets_imgs[j] - if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) + process_batch(False) - for (image_key, reso, _), latent in zip(bucket, latents): - npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key - np.savez(os.path.join(args.train_data_dir, npz_file_name), latent) - - # flip - if args.flip_aug: - latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない - - for (image_key, reso, _), latent in zip(bucket, latents): - npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key - np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent) - - bucket.clear() + # 残りを処理する + process_batch(True) for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)): print(f"bucket {i} {reso}: {count}") @@ -162,8 +221,10 @@ if __name__ == '__main__': parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)') parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") parser.add_argument("--max_resolution", type=str, default="512,512", help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") @@ -174,6 +235,8 @@ if __name__ == '__main__': help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") parser.add_argument("--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") + parser.add_argument("--skip_existing", action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)") args = parser.parse_args() main(args) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index d7166114..609b8c50 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,6 +1,3 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - import argparse import csv import glob @@ -12,35 +9,87 @@ from tqdm import tqdm import numpy as np from tensorflow.keras.models import load_model from huggingface_hub import hf_hub_download +import torch + +import library.train_util as train_util # from wd14 tagger IMAGE_SIZE = 448 -WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger' +# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 +DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2' FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] +def preprocess_image(image): + image = np.array(image) + image = image[:, :, ::-1] # RGB->BGR + + # pad to square + size = max(image.shape[0:2]) + pad_x = size - image.shape[1] + pad_y = size - image.shape[0] + pad_l = pad_x // 2 + pad_t = pad_y // 2 + image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + + interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + + image = image.astype(np.float32) + return image + + +class ImageLoadingPrepDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + image = preprocess_image(image) + tensor = torch.tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + + return (tensor, img_path) + + +def collate_fn_remove_corrupted(batch): + """Collate function that allows to remove corrupted examples in the + dataloader. It expects that the dataloader returns 'None' when that occurs. + The 'None's in the batch are removed. + """ + # Filter out all the Nones (corrupted examples) + batch = list(filter(lambda x: x is not None, batch)) + return batch + + def main(args): # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: - print("downloading wd14 tagger model from hf_hub") + print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") for file in FILES: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( args.model_dir, SUB_DIR), force_download=True, force_filename=file) + else: + print("using existing wd14 tagger model") # 画像を読み込む - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.webp")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.bmp")) + image_paths = train_util.glob_images(args.train_data_dir) print(f"found {len(image_paths)} images.") print("loading model and labels") @@ -75,7 +124,7 @@ def main(args): # Everything else is tags: pick any where prediction confidence > threshold tag_text = "" for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで - if p >= args.thresh: + if p >= args.thresh and i < len(tags): tag_text += ", " + tags[i] if len(tag_text) > 0: @@ -86,34 +135,37 @@ def main(args): if args.debug: print(image_path, tag_text) + # 読み込みの高速化のためにDataLoaderを使うオプション + if args.max_data_loader_n_workers is not None: + dataset = ImageLoadingPrepDataset(image_paths) + data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + else: + data = [[(None, ip)] for ip in image_paths] + b_imgs = [] - for image_path in tqdm(image_paths, smoothing=0.0): - img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く - if img.mode != 'RGB': - img = img.convert("RGB") - img = np.array(img) - img = img[:, :, ::-1] # RGB->BGR + for data_entry in tqdm(data, smoothing=0.0): + for data in data_entry: + if data is None: + continue - # pad to square - size = max(img.shape[0:2]) - pad_x = size - img.shape[1] - pad_y = size - img.shape[0] - pad_l = pad_x // 2 - pad_t = pad_y // 2 - img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) + image, image_path = data + if image is not None: + image = image.detach().numpy() + else: + try: + image = Image.open(image_path) + if image.mode != 'RGB': + image = image.convert("RGB") + image = preprocess_image(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") + continue + b_imgs.append((image_path, image)) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) - # cv2.imshow("img", img) - # cv2.waitKey() - # cv2.destroyAllWindows() - - img = img.astype(np.float32) - b_imgs.append((image_path, img)) - - if len(b_imgs) >= args.batch_size: - run_batch(b_imgs) - b_imgs.clear() + if len(b_imgs) >= args.batch_size: + run_batch(b_imgs) + b_imgs.clear() if len(b_imgs) > 0: run_batch(b_imgs) @@ -124,7 +176,7 @@ def main(args): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO, + parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") @@ -132,6 +184,8 @@ if __name__ == '__main__': help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") + parser.add_argument("--max_data_loader_n_workers", type=int, default=None, + help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)") parser.add_argument("--caption_extention", type=str, default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") diff --git a/library/train_util.py b/library/train_util.py index 0946c31d..459b81a1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -44,7 +44,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset -IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] class ImageInfo(): @@ -141,7 +141,7 @@ class BaseDataset(torch.utils.data.Dataset): if type(str_to) == list: caption = random.choice(str_to) else: - caption = str_to + caption = str_to else: caption = caption.replace(str_from, str_to) @@ -247,7 +247,6 @@ class BaseDataset(torch.utils.data.Dataset): mean_img_ar_error = np.mean(np.abs(img_ar_errors)) self.bucket_info["mean_img_ar_error"] = mean_img_ar_error print(f"mean ar error (without repeats): {mean_img_ar_error}") - # 参照用indexを作る self.buckets_indices: list(BucketBatchIndex) = [] @@ -766,15 +765,30 @@ def debug_dataset(train_dataset, show_input_ids=False): break -def glob_images(dir, base): +def glob_images(directory, base="*"): img_paths = [] for ext in IMAGE_EXTENSIONS: if base == '*': - img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext))) + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: - img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext)))) + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() return img_paths + +def glob_images_pathlib(dir_path, recursive): + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob('*' + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob('*' + ext)) + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() + return image_paths + # endregion @@ -1505,5 +1519,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator): model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) +# endregion + +# region 前処理用 + + +class ImageLoadingDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + + return (tensor_pil, img_path) + # endregion diff --git a/requirements.txt b/requirements.txt index 36f48a0f..709a8342 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate==0.15.0 -transformers==4.25.1 +transformers==4.26.0 ftfy albumentations opencv-python From 93134cdd1595339735e943fa7f22627e02c4ce68 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:03:42 +0900 Subject: [PATCH 07/14] Add tag freq for FinetuneDataset --- library/train_util.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 459b81a1..c1e54517 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -44,7 +44,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset -IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] +# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux? class ImageInfo(): @@ -116,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + if tag and not tag.isspace(): + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + def disable_token_padding(self): self.token_padding_disabled = True @@ -545,14 +556,7 @@ class DreamBoothDataset(BaseDataset): cap_for_img = read_caption(img_path) captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {}) - self.tag_frequency[os.path.basename(dir)] = frequency_for_dir - for caption in captions: - for tag in caption.split(","): - if tag and not tag.isspace(): - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 + self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録 return n_repeats, img_paths, captions @@ -562,10 +566,13 @@ class DreamBoothDataset(BaseDataset): for dir in train_dirs: n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) + self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} + print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -579,9 +586,11 @@ class DreamBoothDataset(BaseDataset): for dir in reg_dirs: n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) + self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") @@ -626,6 +635,7 @@ class FineTuningDataset(BaseDataset): self.train_data_dir = train_data_dir self.batch_size = batch_size + tags_list = [] for image_key, img_md in metadata.items(): # path情報を作る if os.path.exists(image_key): @@ -642,6 +652,7 @@ class FineTuningDataset(BaseDataset): caption = tags elif tags is not None and len(tags) > 0: caption = caption + ', ' + tags + tags_list.append(tags) assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) @@ -655,7 +666,8 @@ class FineTuningDataset(BaseDataset): self.num_train_images = len(metadata) * dataset_repeats self.num_reg_images = 0 - self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + self.set_tag_frequency(os.path.basename(json_file_name), tags_list) + self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} # check existence of all npz files if not self.color_aug: @@ -676,6 +688,8 @@ class FineTuningDataset(BaseDataset): print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") elif not npz_all: print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + if self.flip_aug: + print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") for image_info in self.image_data.values(): image_info.latents_npz = image_info.latents_npz_flipped = None @@ -772,8 +786,8 @@ def glob_images(directory, base="*"): img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - img_paths = list(set(img_paths)) # 重複を排除 - img_paths.sort() + # img_paths = list(set(img_paths)) # 重複を排除 + # img_paths.sort() return img_paths @@ -785,8 +799,8 @@ def glob_images_pathlib(dir_path, recursive): else: for ext in IMAGE_EXTENSIONS: image_paths += list(dir_path.glob('*' + ext)) - image_paths = list(set(image_paths)) # 重複を排除 - image_paths.sort() + # image_paths = list(set(image_paths)) # 重複を排除 + # image_paths.sort() return image_paths # endregion From 58a809eaff535f43e51a81a6014406afd65c033b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:04:03 +0900 Subject: [PATCH 08/14] Add comment --- gen_img_diffusers.py | 12 +++--- train_network.py | 92 ++++++++++++++++++++++---------------------- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 8b29e78e..25a5b2d9 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1845,12 +1845,12 @@ def main(args): text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) else: print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - tokenizer = pipe.tokenizer - del pipe + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe # VAEを読み込む if args.vae is not None: diff --git a/train_network.py b/train_network.py index aebc4a40..88405221 100644 --- a/train_network.py +++ b/train_network.py @@ -1,3 +1,6 @@ +from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from torch.optim import Optimizer +from typing import Optional, Union import importlib import argparse import gc @@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche # Which is a newer release of diffusers than currently packaged with sd-scripts # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts -from typing import Optional, Union -from torch.optim import Optimizer -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION def get_scheduler_fix( name: Union[str, SchedulerType], @@ -52,53 +52,53 @@ def get_scheduler_fix( num_cycles: int = 1, power: float = 1.0, ): - """ - Unified API to get any scheduler from its name. - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_cycles (`int`, *optional*): - The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. - power (`float`, *optional*, defaults to 1.0): - Power factor. See `POLYNOMIAL` scheduler - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + """ + Unified API to get any scheduler from its name. + Args: + name (`str` or `SchedulerType`): + The name of the scheduler to use. + optimizer (`torch.optim.Optimizer`): + The optimizer that will be used during training. + num_warmup_steps (`int`, *optional*): + The number of warmup steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_training_steps (`int``, *optional*): + The number of training steps to do. This is not required by all schedulers (hence the argument being + optional), the function will raise an error if it's unset and the scheduler type requires it. + num_cycles (`int`, *optional*): + The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. + power (`float`, *optional*, defaults to 1.0): + Power factor. See `POLYNOMIAL` scheduler + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + """ + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles - ) + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power - ) + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power + ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) def train(args): @@ -135,7 +135,7 @@ def train(args): train_util.debug_dataset(train_dataset) return if len(train_dataset) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)") return # acceleratorを準備する @@ -224,7 +224,7 @@ def train(args): # lr schedulerを用意する # lr_scheduler = diffusers.optimization.get_scheduler( lr_scheduler = get_scheduler_fix( - args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) From 73d612ff9cd14697e4d2645c05b79fd64880791f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:04:37 +0900 Subject: [PATCH 09/14] Add cleaning patterns --- finetune/make_captions_by_git.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index daaef9e7..ebc91920 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -13,17 +13,26 @@ import library.train_util as train_util DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'), - re.compile(r'that says (" ?[^"]*"|\w+)')] - +PATTERN_REPLACE = [ + re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), + re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), + re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), + re.compile(r'with the number \d+ on (it|\w+ \w+)'), + re.compile(r'with the words "'), + re.compile(r'word \w+ on it'), + re.compile(r'that says the word \w+ on it'), + re.compile('that says\'the word "( on it)?'), +] # 誤検知しまくりの with the word xxxx を消す + + def remove_words(captions, debug): removed_caps = [] for caption in captions: cap = caption for pat in PATTERN_REPLACE: - cap = pat.sub("", caption) + cap = pat.sub("", cap) if debug and cap != caption: print(caption) print(cap) @@ -87,7 +96,7 @@ def main(args): if args.max_data_loader_n_workers is not None: dataset = train_util.ImageLoadingDataset(image_paths) data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) else: data = [[(None, ip)] for ip in image_paths] @@ -96,7 +105,7 @@ def main(args): for data in data_entry: if data is None: continue - + image, image_path = data if image is None: try: From 76f53429be917202b646b512f6c6bc8a641f13cf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:05:14 +0900 Subject: [PATCH 10/14] Fix existing npz skip feature --- finetune/prepare_buckets_latents.py | 43 +++++++++++++++-------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 537626d8..d1b9ea25 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -63,25 +63,6 @@ def main(args): print(f"no metadata / メタデータファイルがありません: {args.in_json}") return - # 既に存在するファイルをfilterする - if args.skip_existing: - filtered = [] - for image_path in image_paths: - image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] - - npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz" - if os.path.exists(npz_file_name_flip): - if not args.flip_aug: - continue - - npz_file_name_flip = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" - if os.path.exists(npz_file_name_flip): - continue - - filtered.apppend(image_path) - print(f"number of skipped images (npz already exists) / npzファイルが存在するためスキップした画像数: {len(image_paths) - len(filtered)}") - image_paths = filtered - weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -128,8 +109,8 @@ def main(args): # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: dataset = train_util.ImageLoadingDataset(image_paths) - data = torch.util.data.DataLoader(dataset, batch_size=1, shuffle=False, - num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) + data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, + num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) else: data = [[(None, ip)] for ip in image_paths] @@ -153,6 +134,7 @@ def main(args): if image_key not in metadata: metadata[image_key] = {} + # 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変 aspect_ratio = image.width / image.height ar_errors = bucket_aspect_ratios - aspect_ratio bucket_id = np.abs(ar_errors).argmin() @@ -176,6 +158,25 @@ def main(args): assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" + # 既に存在するファイルがあればshapeを確認して同じならskipする + if args.skip_existing: + npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] + if args.flip_aug: + npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") + + found = True + for npz_file in npz_files: + if not os.path.exists(npz_file): + found = False + break + + dat = np.load(npz_file)['arr_0'] + if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 + found = False + break + if found: + continue + # 画像をリサイズしてトリミングする # PILにinter_areaがないのでcv2で…… image = np.array(image) From 1bec2bfe07c2d6e8cccd4b41e175ad06e82941dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:05:55 +0900 Subject: [PATCH 11/14] Add cleaning duplicated tags --- finetune/clean_captions_and_tags.py | 65 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/finetune/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py index 8f53737d..11a59b1f 100644 --- a/finetune/clean_captions_and_tags.py +++ b/finetune/clean_captions_and_tags.py @@ -5,13 +5,32 @@ import argparse import glob import os import json +import re from tqdm import tqdm +PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') +PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') +PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') +PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') + +# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する +PATTERNS_REMOVE_IN_MULTI = [ + PATTERN_HAIR_LENGTH, + PATTERN_HAIR_CUT, + re.compile(r', [\w\-]+ eyes, '), + re.compile(r', ([\w\-]+ sleeves|sleeveless), '), + # 複数の髪型定義がある場合は削除する + re.compile( + r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), +] + def clean_tags(image_key, tags): # replace '_' to ' ' + tags = tags.replace('^_^', '^@@@^') tags = tags.replace('_', ' ') + tags = tags.replace('^@@@^', '^_^') # remove rating: deepdanbooruのみ tokens = tags.split(", rating") @@ -26,6 +45,37 @@ def clean_tags(image_key, tags): print(f"{image_key} {tags}") tags = tokens[0] + tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 + + # 複数の人物がいる場合は髪色等のタグを削除する + if 'girls' in tags or 'boys' in tags: + for pat in PATTERNS_REMOVE_IN_MULTI: + found = pat.findall(tags) + if len(found) > 1: # 二つ以上、タグがある + tags = pat.sub("", tags) + + # 髪の特殊対応 + srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) + if srch_hair_len: + org = srch_hair_len.group() + tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) + + found = PATTERN_HAIR.findall(tags) + if len(found) > 1: + tags = PATTERN_HAIR.sub("", tags) + + if srch_hair_len: + tags = tags.replace(", @@@, ", org) # 戻す + + # white shirtとshirtみたいな重複タグの削除 + found = PATTERN_WORD.findall(tags) + for word in found: + if re.search(f", ((\w+) )+{word}, ", tags): + tags = tags.replace(f", {word}, ", "") + + tags = tags.replace(", , ", ", ") + assert tags.startswith(", ") and tags.endswith(", ") + tags = tags[2:-2] return tags @@ -88,13 +138,23 @@ def main(args): if tags is None: print(f"image does not have tags / メタデータにタグがありません: {image_key}") else: - metadata[image_key]['tags'] = clean_tags(image_key, tags) + org = tags + tags = clean_tags(image_key, tags) + metadata[image_key]['tags'] = tags + if args.debug and org != tags: + print("FROM: " + org) + print("TO: " + tags) caption = metadata[image_key].get('caption') if caption is None: print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") else: - metadata[image_key]['caption'] = clean_caption(caption) + org = caption + caption = clean_caption(caption) + metadata[image_key]['caption'] = caption + if args.debug and org != caption: + print("FROM: " + org) + print("TO: " + caption) # metadataを書き出して終わり print(f"writing metadata: {args.out_json}") @@ -108,6 +168,7 @@ if __name__ == '__main__': # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") + parser.add_argument("--debug", action="store_true", help="debug mode") args, unknown = parser.parse_known_args() if len(unknown) == 1: From 26efa88908a184eb48b990b92f043193c0d0bf01 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:02:49 +0900 Subject: [PATCH 12/14] Update README.md --- README.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/README.md b/README.md index 31c08a6e..7dd91db7 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,37 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Note: The LoRA models for SD 2.x is not supported too in Web UI. +- 3 Feb. 2023, 2023/2/3 + - Update finetune preprocessing scripts. + - ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev! + - The default weights of ``tag_images_by_wd14_tagger.py`` is now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` by ``--repo_id`` option. Thanks to SmilingWolf for the great work. + - To change the weight, remove ``wd14_tagger_model`` folder, and run the script again. + - ``--max_data_loader_n_workers`` option is added to each script. This option uses the DataLoader for data loading to speed up loading, 20%~30% faster. + - Please specify 2 or 4, depends on the number of CPU cores. + - ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``. + - ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning. + - Usage is almost the same as ``make_captions.py``, but batch size should be smaller. + - ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``). + - ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option. + - ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. if ``black hair`` is with ``red hair``, both are removed. + - Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko! + - __All tags and number of occurrences of the tag are recorded.__ If you do not want it, disable metadata storing with ``--no_metadata`` option. + + - fine tuning用の前処理スクリプト群を更新しました。 + - 拡張子 ``.bmp`` と ``.jpeg`` をサポートしました。breakcore2氏およびp1atdev氏に感謝します。 + - ``tag_images_by_wd14_tagger.py`` のデフォルトの重みを ``SmilingWolf/wd-v1-4-convnext-tagger-v2`` に更新しました。他の ``SmilingWolf`` 氏の重みも ``--repo_id`` オプションで指定可能です。SmilingWolf氏に感謝します。 + - 重みを変更するときには ``wd14_tagger_model`` フォルダを削除してからスクリプトを再実行してください。 + - ``--max_data_loader_n_workers`` オプションが各スクリプトに追加されました。DataLoaderを用いることで読み込み処理を並列化し、処理を20~30%程度高速化します。 + - CPUのコア数に応じて2~4程度の値を指定してください。 + - ``--recursive`` オプションを ``merge_dd_tags_to_metadata.py`` と ``merge_captions_to_metadata.py`` に追加しました。``--full_path`` を指定したときのみ使用可能です。 + - ``make_captions_by_git.py`` を追加しました。[GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) を用いてキャプションニングを行います。 + - 使用法は ``make_captions.py``とほぼ同じですがバッチサイズは小さめにしてください。 + - ``--remove_words`` オプションを指定するとテキスト読み取りを可能な限り削除します(``the word "XXXX" on it``のようなもの)。 + - ``--skip_existing`` を ``prepare_buckets_latents.py`` に追加しました。すでにnpzファイルがある画像の処理をスキップします。 + - ``clean_captions_and_tags.py``を重複タグや矛盾するタグを削除するよう機能追加しました。例:``white shirt`` タグがある場合、 ``shirt`` タグは削除されます。また``black hair``と``red hair``の両方がある場合、両方とも削除されます。 + - ``train_network.py``で使用されているタグと回数をメタデータに記録するようになりました。space-nuko氏に感謝します。 + - __すべてのタグと回数がメタデータに記録されます__ 望まない場合には``--no_metadata option``オプションでメタデータの記録を停止してください。 + - 29 Jan. 2023, 2023/1/29 - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev! - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py`` From b18a09edb50d5cfdf9502d7cfca24626f25a846e Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:09:55 +0900 Subject: [PATCH 13/14] Update README.md --- README.md | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 7dd91db7..f605f62a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Note: The LoRA models for SD 2.x is not supported too in Web UI. - Please specify 2 or 4, depends on the number of CPU cores. - ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``. - ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning. + - ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade). - Usage is almost the same as ``make_captions.py``, but batch size should be smaller. - ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``). - ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option. @@ -30,6 +31,7 @@ Note: The LoRA models for SD 2.x is not supported too in Web UI. - CPUのコア数に応じて2~4程度の値を指定してください。 - ``--recursive`` オプションを ``merge_dd_tags_to_metadata.py`` と ``merge_captions_to_metadata.py`` に追加しました。``--full_path`` を指定したときのみ使用可能です。 - ``make_captions_by_git.py`` を追加しました。[GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) を用いてキャプションニングを行います。 + - ``requirements.txt`` が更新されていますので、[ライブラリをアップデート](https://github.com/kohya-ss/sd-scripts/blob/main/README-ja.md#%E3%82%A2%E3%83%83%E3%83%97%E3%82%B0%E3%83%AC%E3%83%BC%E3%83%89)してください。 - 使用法は ``make_captions.py``とほぼ同じですがバッチサイズは小さめにしてください。 - ``--remove_words`` オプションを指定するとテキスト読み取りを可能な限り削除します(``the word "XXXX" on it``のようなもの)。 - ``--skip_existing`` を ``prepare_buckets_latents.py`` に追加しました。すでにnpzファイルがある画像の処理をスキップします。 @@ -37,27 +39,6 @@ Note: The LoRA models for SD 2.x is not supported too in Web UI. - ``train_network.py``で使用されているタグと回数をメタデータに記録するようになりました。space-nuko氏に感謝します。 - __すべてのタグと回数がメタデータに記録されます__ 望まない場合には``--no_metadata option``オプションでメタデータの記録を停止してください。 -- 29 Jan. 2023, 2023/1/29 - - Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev! - - Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py`` - - ``--lr_scheduler_num_cycles`` と ``--lr_scheduler_power`` オプションを ``train_network.py`` に追加しました。前者は cosine_with_restarts、後者は polynomial の学習率スケジューラに有効です。mgz-dev氏に感謝します。 - - ``convert_diffusers20_original_sd.py`` で SD 形式から Diffusers に変換するときの U-Net の ``sample_size`` パラメータを ``64`` に修正しました。 -- 26 Jan. 2023, 2023/1/26 - - Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.) - - Textual Inversionの学習をサポートしました。ドキュメントは[こちら](./train_ti_README-ja.md)。 -- 24 Jan. 2023, 2023/1/24 - - Change the default save format to ``.safetensors`` for ``train_network.py``. - - Add ``--save_n_epoch_ratio`` option to specify how often to save. Thanks to forestsource! - - For example, if 5 is specified, 5 (or 6) files will be saved in training. - - Add feature to pre-calculate hash to reduce loading time in the extension. Thanks to space-nuko! - - Add bucketing metadata. Thanks to space-nuko! - - Fix an error with bf16 model in ``gen_img_diffusers.py``. - - ``train_network.py`` のモデル保存形式のデフォルトを ``.safetensors`` に変更しました。 - - モデルを保存する頻度を指定する ``--save_n_epoch_ratio`` オプションが追加されました。forestsource氏に感謝します。 - - たとえば 5 を指定すると、学習終了までに合計で5個(または6個)のファイルが保存されます。 - - 拡張でモデル読み込み時間を短縮するためのハッシュ事前計算の機能を追加しました。space-nuko氏に感謝します。 - - メタデータにbucket情報が追加されました。space-nuko氏に感謝します。 - - ``gen_img_diffusers.py`` でbf16形式のモデルを読み込んだときのエラーを修正しました。 Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。 From 9682772b09a37d42dbf9963153a7ceaabcff6182 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:10:17 +0900 Subject: [PATCH 14/14] Update README-ja.md --- README-ja.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README-ja.md b/README-ja.md index c661a168..2427652c 100644 --- a/README-ja.md +++ b/README-ja.md @@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。(bf1 cd sd-scripts git pull .\venv\Scripts\activate -pip install --upgrade -r +pip install --upgrade -r requirements.txt ``` コマンドが成功すれば新しいバージョンが使用できます。