From 93134cdd1595339735e943fa7f22627e02c4ce68 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 Feb 2023 21:03:42 +0900 Subject: [PATCH] Add tag freq for FinetuneDataset --- library/train_util.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 459b81a1..c1e54517 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -44,7 +44,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset -IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] +# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux? class ImageInfo(): @@ -116,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + if tag and not tag.isspace(): + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + def disable_token_padding(self): self.token_padding_disabled = True @@ -545,14 +556,7 @@ class DreamBoothDataset(BaseDataset): cap_for_img = read_caption(img_path) captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {}) - self.tag_frequency[os.path.basename(dir)] = frequency_for_dir - for caption in captions: - for tag in caption.split(","): - if tag and not tag.isspace(): - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 + self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録 return n_repeats, img_paths, captions @@ -562,10 +566,13 @@ class DreamBoothDataset(BaseDataset): for dir in train_dirs: n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) + self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} + print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -579,9 +586,11 @@ class DreamBoothDataset(BaseDataset): for dir in reg_dirs: n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) + self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") @@ -626,6 +635,7 @@ class FineTuningDataset(BaseDataset): self.train_data_dir = train_data_dir self.batch_size = batch_size + tags_list = [] for image_key, img_md in metadata.items(): # path情報を作る if os.path.exists(image_key): @@ -642,6 +652,7 @@ class FineTuningDataset(BaseDataset): caption = tags elif tags is not None and len(tags) > 0: caption = caption + ', ' + tags + tags_list.append(tags) assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) @@ -655,7 +666,8 @@ class FineTuningDataset(BaseDataset): self.num_train_images = len(metadata) * dataset_repeats self.num_reg_images = 0 - self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + self.set_tag_frequency(os.path.basename(json_file_name), tags_list) + self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} # check existence of all npz files if not self.color_aug: @@ -676,6 +688,8 @@ class FineTuningDataset(BaseDataset): print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") elif not npz_all: print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + if self.flip_aug: + print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") for image_info in self.image_data.values(): image_info.latents_npz = image_info.latents_npz_flipped = None @@ -772,8 +786,8 @@ def glob_images(directory, base="*"): img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - img_paths = list(set(img_paths)) # 重複を排除 - img_paths.sort() + # img_paths = list(set(img_paths)) # 重複を排除 + # img_paths.sort() return img_paths @@ -785,8 +799,8 @@ def glob_images_pathlib(dir_path, recursive): else: for ext in IMAGE_EXTENSIONS: image_paths += list(dir_path.glob('*' + ext)) - image_paths = list(set(image_paths)) # 重複を排除 - image_paths.sort() + # image_paths = list(set(image_paths)) # 重複を排除 + # image_paths.sort() return image_paths # endregion