mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add tag freq for FinetuneDataset
This commit is contained in:
@@ -44,7 +44,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
|||||||
|
|
||||||
# region dataset
|
# region dataset
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||||
|
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
||||||
|
|
||||||
|
|
||||||
class ImageInfo():
|
class ImageInfo():
|
||||||
@@ -116,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.replacements = {}
|
self.replacements = {}
|
||||||
|
|
||||||
|
def set_tag_frequency(self, dir_name, captions):
    """Accumulate per-tag occurrence counts for ``dir_name``.

    Each caption is treated as a comma-separated list of tags. Every tag is
    stripped of surrounding whitespace and lower-cased before it is counted,
    so ``"Solo"``, ``" solo"`` and ``"solo "`` all contribute to the same
    ``"solo"`` entry. Counts are added to any counts already recorded for
    ``dir_name`` in ``self.tag_frequency``, so the method may be called more
    than once for the same key.

    Args:
        dir_name: Key under which the counts are stored in
            ``self.tag_frequency``.
        captions: Iterable of caption strings. ``None`` and empty entries
            are skipped.
    """
    # Reuse the existing counter for this key, creating it on first use.
    frequency_for_dir = self.tag_frequency.setdefault(dir_name, {})

    for caption in captions:
        if not caption:
            # Nothing to count for this image (None or empty string).
            continue
        for tag in caption.split(","):
            # Normalize before counting; without the strip, the space after
            # each comma would make " solo" a different key from "solo".
            tag = tag.strip().lower()
            if tag:
                frequency_for_dir[tag] = frequency_for_dir.get(tag, 0) + 1
|
||||||
|
|
||||||
def disable_token_padding(self):
|
def disable_token_padding(self):
|
||||||
self.token_padding_disabled = True
|
self.token_padding_disabled = True
|
||||||
|
|
||||||
@@ -545,14 +556,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
cap_for_img = read_caption(img_path)
|
cap_for_img = read_caption(img_path)
|
||||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||||
|
|
||||||
frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
|
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
||||||
self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
|
|
||||||
for caption in captions:
|
|
||||||
for tag in caption.split(","):
|
|
||||||
if tag and not tag.isspace():
|
|
||||||
tag = tag.lower()
|
|
||||||
frequency = frequency_for_dir.get(tag, 0)
|
|
||||||
frequency_for_dir[tag] = frequency + 1
|
|
||||||
|
|
||||||
return n_repeats, img_paths, captions
|
return n_repeats, img_paths, captions
|
||||||
|
|
||||||
@@ -562,10 +566,13 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for dir in train_dirs:
|
for dir in train_dirs:
|
||||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
||||||
num_train_images += n_repeats * len(img_paths)
|
num_train_images += n_repeats * len(img_paths)
|
||||||
|
|
||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||||
self.register_image(info)
|
self.register_image(info)
|
||||||
|
|
||||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_train_images} train images with repeating.")
|
print(f"{num_train_images} train images with repeating.")
|
||||||
self.num_train_images = num_train_images
|
self.num_train_images = num_train_images
|
||||||
|
|
||||||
@@ -579,9 +586,11 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for dir in reg_dirs:
|
for dir in reg_dirs:
|
||||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
||||||
num_reg_images += n_repeats * len(img_paths)
|
num_reg_images += n_repeats * len(img_paths)
|
||||||
|
|
||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||||
reg_infos.append(info)
|
reg_infos.append(info)
|
||||||
|
|
||||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_reg_images} reg images.")
|
print(f"{num_reg_images} reg images.")
|
||||||
@@ -626,6 +635,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.train_data_dir = train_data_dir
|
self.train_data_dir = train_data_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
tags_list = []
|
||||||
for image_key, img_md in metadata.items():
|
for image_key, img_md in metadata.items():
|
||||||
# path情報を作る
|
# path情報を作る
|
||||||
if os.path.exists(image_key):
|
if os.path.exists(image_key):
|
||||||
@@ -642,6 +652,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
caption = tags
|
caption = tags
|
||||||
elif tags is not None and len(tags) > 0:
|
elif tags is not None and len(tags) > 0:
|
||||||
caption = caption + ', ' + tags
|
caption = caption + ', ' + tags
|
||||||
|
tags_list.append(tags)
|
||||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||||
|
|
||||||
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
||||||
@@ -655,7 +666,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.num_train_images = len(metadata) * dataset_repeats
|
self.num_train_images = len(metadata) * dataset_repeats
|
||||||
self.num_reg_images = 0
|
self.num_reg_images = 0
|
||||||
|
|
||||||
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||||
|
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||||
|
|
||||||
# check existence of all npz files
|
# check existence of all npz files
|
||||||
if not self.color_aug:
|
if not self.color_aug:
|
||||||
@@ -676,6 +688,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
||||||
elif not npz_all:
|
elif not npz_all:
|
||||||
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
||||||
|
if self.flip_aug:
|
||||||
|
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||||
for image_info in self.image_data.values():
|
for image_info in self.image_data.values():
|
||||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||||
|
|
||||||
@@ -772,8 +786,8 @@ def glob_images(directory, base="*"):
|
|||||||
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||||
else:
|
else:
|
||||||
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||||
img_paths = list(set(img_paths)) # 重複を排除
|
# img_paths = list(set(img_paths)) # 重複を排除
|
||||||
img_paths.sort()
|
# img_paths.sort()
|
||||||
return img_paths
|
return img_paths
|
||||||
|
|
||||||
|
|
||||||
@@ -785,8 +799,8 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
else:
|
else:
|
||||||
for ext in IMAGE_EXTENSIONS:
|
for ext in IMAGE_EXTENSIONS:
|
||||||
image_paths += list(dir_path.glob('*' + ext))
|
image_paths += list(dir_path.glob('*' + ext))
|
||||||
image_paths = list(set(image_paths)) # 重複を排除
|
# image_paths = list(set(image_paths)) # 重複を排除
|
||||||
image_paths.sort()
|
# image_paths.sort()
|
||||||
return image_paths
|
return image_paths
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
Reference in New Issue
Block a user