Add tag freq for FinetuneDataset

This commit is contained in:
Kohya S
2023-02-03 21:03:42 +09:00
parent 57d8483eaf
commit 93134cdd15

View File

@@ -44,7 +44,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo():
@@ -116,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self):
self.token_padding_disabled = True
@@ -545,14 +556,7 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return n_repeats, img_paths, captions
@@ -562,10 +566,13 @@ class DreamBoothDataset(BaseDataset):
for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
@@ -579,9 +586,11 @@ class DreamBoothDataset(BaseDataset):
for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.")
@@ -626,6 +635,7 @@ class FineTuningDataset(BaseDataset):
self.train_data_dir = train_data_dir
self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
if os.path.exists(image_key):
@@ -642,6 +652,7 @@ class FineTuningDataset(BaseDataset):
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
@@ -655,7 +666,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files
if not self.color_aug:
@@ -676,6 +688,8 @@ class FineTuningDataset(BaseDataset):
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
elif not npz_all:
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
if self.flip_aug:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
@@ -772,8 +786,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths)) # 重複を排除
img_paths.sort()
# img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort()
return img_paths
@@ -785,8 +799,8 @@ def glob_images_pathlib(dir_path, recursive):
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob('*' + ext))
image_paths = list(set(image_paths)) # 重複を排除
image_paths.sort()
# image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort()
return image_paths
# endregion