add detail dataset config feature by extra config file (#227)

* add config file schema

* change config file specification

* refactor config utility

* unify batch_size to train_batch_size

* fix indent size

* use batch_size instead of train_batch_size

* make cache_latents configurable on subset

* rename options
* bucket_repo_range
* shuffle_keep_tokens

* update readme

* revert to min_bucket_reso & max_bucket_reso

* use subset structure in dataset

* format import lines

* split mode specific options

* use only valid subset

* change valid subsets name

* manage multiple datasets by dataset group

* update config file sanitizer

* prune redundant validation

* add comments

* update type annotation

* rename json_file_name to metadata_file

* ignore when image dir is invalid

* fix tag shuffle and dropout

* ignore duplicated subset

* add method to check latent cachability

* fix format

* fix bug

* update caption dropout default values

* update annotation

* fix bug

* add option to enable bucket shuffle across dataset

* update blueprint generate function

* use blueprint generator for dataset initialization

* delete duplicated function

* update config readme

* delete debug print

* print dataset and subset info as info

* enable bucket_shuffle_across_dataset option

* update config readme for clarification

* compensate quotes for string option example

* fix bug of bad usage of join

* conserve trained metadata backward compatibility

* enable shuffle in data loader by default

* delete resolved TODO

* add comment for image data handling

* fix reference bug

* fix undefined variable bug

* prevent raise overwriting

* assert image_dir and metadata_file validity

* add debug message for ignoring subset

* fix inconsistent import statement

* loosen too strict validation on float value

* sanitize argument parser separately

* make image_dir optional for fine tuning dataset

* fix import

* fix trailing characters in print

* parse flexible dataset config deterministically

* use relative import

* print supplementary message for parsing error

* add note about different methods

* add note of benefit of separate dataset

* add error example

* add note for english readme plan

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
fur0ut0
2023-03-01 20:58:08 +09:00
committed by GitHub
parent 82707654ad
commit 8abb8645ae
8 changed files with 1370 additions and 321 deletions

View File

@@ -6,8 +6,15 @@ import json
import re
import shutil
import time
from typing import Dict, List, NamedTuple, Tuple
from typing import Optional, Union
from typing import (
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
from accelerate import Accelerator
import glob
import math
@@ -203,23 +210,93 @@ class BucketBatchIndex(NamedTuple):
batch_index: int
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
super().__init__()
self.tokenizer: CLIPTokenizer = tokenizer
self.max_token_length = max_token_length
class AugHelper:
def __init__(self):
# prepare all possible augmentators
color_aug_method = albu.OneOf([
albu.HueSaturationValue(8, 0, 0, p=.5),
albu.RandomGamma((95, 105), p=.5),
], p=.33)
flip_aug_method = albu.HorizontalFlip(p=0.5)
# key: (use_color_aug, use_flip_aug)
self.augmentors = {
(True, True): albu.Compose([
color_aug_method,
flip_aug_method,
], p=1.),
(True, False): albu.Compose([
color_aug_method,
], p=1.),
(False, True): albu.Compose([
flip_aug_method,
], p=1.),
(False, False): None
}
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
return self.augmentors[(use_color_aug, use_flip_aug)]
class BaseSubset:
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
self.shuffle_caption = shuffle_caption
self.shuffle_keep_tokens = shuffle_keep_tokens
self.keep_tokens = keep_tokens
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
self.random_crop = random_crop
self.caption_dropout_rate = caption_dropout_rate
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.img_count = 0
class DreamBoothSubset(BaseSubset):
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
return NotImplemented
return self.image_dir == other.image_dir
class FineTuningSubset(BaseSubset):
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
self.metadata_file = metadata_file
def __eq__(self, other) -> bool:
if not isinstance(other, FineTuningSubset):
return NotImplemented
return self.metadata_file == other.metadata_file
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
super().__init__()
self.tokenizer = tokenizer
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.face_crop_aug_range = face_crop_aug_range
self.flip_aug = flip_aug
self.color_aug = color_aug
self.debug_dataset = debug_dataset
self.random_crop = random_crop
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
self.token_padding_disabled = False
self.dataset_dirs_info = {}
self.reg_dataset_dirs_info = {}
self.tag_frequency = {}
self.enable_bucket = False
@@ -233,42 +310,20 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0
# augmentation
flip_p = 0.5 if flip_aug else 0.0
if color_aug:
# わりと弱めの色合いaugmentationbrightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
self.aug = albu.Compose([
albu.OneOf([
albu.HueSaturationValue(8, 0, 0, p=.5),
albu.RandomGamma((95, 105), p=.5),
], p=.33),
albu.HorizontalFlip(p=flip_p)
], p=1.)
elif flip_aug:
self.aug = albu.Compose([
albu.HorizontalFlip(p=flip_p)
], p=1.)
else:
self.aug = None
self.aug_helper = AugHelper()
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
self.replacements = {}
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
# コンストラクタで渡さないのはTextual Inversionで意識したくないからということにしておく
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
self.tag_dropout_rate = tag_dropout_rate
self.shuffle_buckets()
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -286,42 +341,36 @@ class BaseDataset(torch.utils.data.Dataset):
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
def process_caption(self, caption):
def process_caption(self, subset: BaseSubset, caption):
# dropoutの決定tag dropがこのメソッド内にあるのでここで行うのが良い
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
if is_drop_out:
caption = ""
else:
if self.shuffle_caption or self.tag_dropout_rate > 0:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
if subset.caption_tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= self.tag_dropout_rate:
if random.random() >= subset.caption_tag_dropout_rate:
l.append(token)
return l
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
if self.shuffle_caption:
random.shuffle(tokens)
fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")]
if subset.keep_tokens >= 0:
fixed_tokens = flex_tokens[:subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens:]
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
if self.shuffle_caption:
random.shuffle(tokens)
flex_tokens = dropout_tags(flex_tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
caption = ", ".join(fixed_tokens + flex_tokens)
# textual inversion対応
for str_from, str_to in self.replacements.items():
@@ -375,8 +424,9 @@ class BaseDataset(torch.utils.data.Dataset):
input_ids = torch.stack(iids_list) # 3,77
return input_ids
def register_image(self, info: ImageInfo):
def register_image(self, info: ImageInfo, subset: BaseSubset):
self.image_data[info.image_key] = info
self.image_to_subset[info.image_key] = subset
def make_buckets(self):
'''
@@ -475,7 +525,7 @@ class BaseDataset(torch.utils.data.Dataset):
img = np.array(image, np.uint8)
return img
def trim_and_resize_if_required(self, image, reso, resized_size):
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -485,22 +535,27 @@ class BaseDataset(torch.utils.data.Dataset):
image_height, image_width = image.shape[0:2]
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p:p + reso[0]]
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p:p + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image
def is_latent_cachable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae):
# TODO ここを高速化したい
print("caching latents.")
for info in tqdm(self.image_data.values()):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents)
@@ -510,13 +565,13 @@ class BaseDataset(torch.utils.data.Dataset):
continue
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
if self.flip_aug:
if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -526,11 +581,11 @@ class BaseDataset(torch.utils.data.Dataset):
image = Image.open(image_path)
return image.size
def load_image_with_face_info(self, image_path: str):
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = self.load_image(image_path)
face_cx = face_cy = face_w = face_h = 0
if self.face_crop_aug_range is not None:
if subset.face_crop_aug_range is not None:
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
if len(tokens) >= 5:
face_cx = int(tokens[-4])
@@ -541,7 +596,7 @@ class BaseDataset(torch.utils.data.Dataset):
return img, face_cx, face_cy, face_w, face_h
# いい感じに切り出す
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
height, width = image.shape[0:2]
if height == self.height and width == self.width:
return image
@@ -549,8 +604,8 @@ class BaseDataset(torch.utils.data.Dataset):
# 画像サイズはsizeより大きいのでリサイズする
face_size = max(face_w, face_h)
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
if min_scale >= max_scale: # range指定がmin==max
scale = min_scale
else:
@@ -568,13 +623,13 @@ class BaseDataset(torch.utils.data.Dataset):
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
if self.random_crop:
if subset.random_crop:
# 背景も含めるために顔を中心に置く確率を高めつつずらす
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
else:
# range指定があるときのみ、すこしだけランダムにわりと適当
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > self.size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
@@ -597,9 +652,6 @@ class BaseDataset(torch.utils.data.Dataset):
return self._length
def __getitem__(self, index):
if index == 0:
self.shuffle_buckets()
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
@@ -612,28 +664,29 @@ class BaseDataset(torch.utils.data.Dataset):
for image_key in bucket[image_index:image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
# image/latentsを処理する
if image_info.latents is not None:
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None:
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
latents = torch.FloatTensor(latents)
image = None
else:
# 画像を読み込み、必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
else:
if face_cx > 0: # 顔位置情報あり
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
elif im_h > self.height or im_w > self.width:
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
if im_h > self.height:
p = random.randint(0, im_h - self.height)
img = img[p:p + self.height]
@@ -645,8 +698,9 @@ class BaseDataset(torch.utils.data.Dataset):
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# augmentation
if self.aug is not None:
img = self.aug(image=img)['image']
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
if aug is not None:
img = aug(image=img)['image']
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
@@ -654,7 +708,7 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image)
latents_list.append(latents)
caption = self.process_caption(image_info.caption)
caption = self.process_caption(subset, image_info.caption)
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
@@ -685,9 +739,8 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -710,7 +763,7 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path):
def read_caption(img_path, caption_extension):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
@@ -733,153 +786,170 @@ class DreamBoothDataset(BaseDataset):
break
return caption
def load_dreambooth_dir(dir):
if not os.path.isdir(dir):
# print(f"ignore file: {dir}")
return 0, [], []
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
tokens = os.path.basename(dir).split('_')
try:
n_repeats = int(tokens[0])
except ValueError as e:
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
return 0, [], []
caption_by_folder = '_'.join(tokens[1:])
img_paths = glob_images(dir, "*")
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
img_paths = glob_images(subset.image_dir, "*")
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
captions.append("")
else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return img_paths, captions
return n_repeats, img_paths, captions
print("prepare train images.")
train_dirs = os.listdir(train_data_dir)
print("prepare images.")
num_train_images = 0
for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths)
num_reg_images = 0
reg_infos: List[ImageInfo] = []
for subset in subsets:
if subset.num_repeats < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
continue
if subset in self.subsets:
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
continue
img_paths, captions = load_dreambooth_dir(subset)
if len(img_paths) < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
continue
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
else:
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if subset.is_reg:
reg_infos.append(info)
else:
self.register_image(info, subset)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
# reg imageは数を数えて学習画像と同じ枚数にする
num_reg_images = 0
if reg_data_dir:
print("prepare reg images.")
reg_infos: List[ImageInfo] = []
print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
reg_dirs = os.listdir(reg_data_dir)
for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
if num_reg_images == 0:
print("no regularization images / 正則化画像が見つかりませんでした")
else:
# num_repeatsを計算するどうせ大した数ではないのでループで処理する
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
if first_loop:
self.register_image(info)
n += info.num_repeats
else:
info.num_repeats += 1
n += 1
if n >= num_train_images:
break
first_loop = False
if num_reg_images == 0:
print("no regularization images / 正則化画像が見つかりませんでした")
else:
# num_repeatsを計算するどうせ大した数ではないのでループで処理する
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1
n += 1
if n >= num_train_images:
break
first_loop = False
self.num_reg_images = num_reg_images
class FineTuningDataset(BaseDataset):
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
# メタデータを読み込む
if os.path.exists(json_file_name):
print(f"loading existing metadata: {json_file_name}")
with open(json_file_name, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
self.metadata = metadata
self.train_data_dir = train_data_dir
self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
if os.path.exists(image_key):
abs_path = image_key
else:
# わりといい加減だがいい方法が思いつかん
abs_path = glob_images(train_data_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]
caption = img_md.get('caption')
tags = img_md.get('tags')
if caption is None:
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
image_info.image_size = img_md.get('train_resolution')
if not self.color_aug and not self.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
self.register_image(image_info)
self.num_train_images = len(metadata) * dataset_repeats
self.num_train_images = 0
self.num_reg_images = 0
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
for subset in subsets:
if subset.num_repeats < 1:
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
continue
if subset in self.subsets:
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
continue
# メタデータを読み込む
if os.path.exists(subset.metadata_file):
print(f"loading existing metadata: {subset.metadata_file}")
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
if len(metadata) < 1:
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
continue
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
if os.path.exists(image_key):
abs_path = image_key
else:
# わりといい加減だがいい方法が思いつかん
abs_path = glob_images(subset.image_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]
caption = img_md.get('caption')
tags = img_md.get('tags')
if caption is None:
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info.image_size = img_md.get('train_resolution')
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
self.register_image(image_info, subset)
self.num_train_images += len(metadata) * subset.num_repeats
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
subset.img_count = len(metadata)
self.subsets.append(subset)
# check existence of all npz files
use_npz_latents = not (self.color_aug or self.random_crop)
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
flip_aug_in_subset = False
npz_any = False
npz_all = True
for image_info in self.image_data.values():
subset = self.image_to_subset[image_info.image_key]
has_npz = image_info.latents_npz is not None
npz_any = npz_any or has_npz
if self.flip_aug:
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
@@ -891,7 +961,7 @@ class FineTuningDataset(BaseDataset):
elif not npz_all:
use_npz_latents = False
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
if self.flip_aug:
if flip_aug_in_subset:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -937,7 +1007,7 @@ class FineTuningDataset(BaseDataset):
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, image_key):
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + '.npz'
@@ -949,8 +1019,8 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
# image_key is relative path
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
if not os.path.exists(npz_file_norm):
npz_file_norm = None
@@ -961,6 +1031,49 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
super().__init__(datasets)
self.image_data = {}
self.num_train_images = 0
self.num_reg_images = 0
# simply concat together
# TODO: handling image_data key duplication among dataset
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
for dataset in datasets:
self.image_data.update(dataset.image_data)
self.num_train_images += dataset.num_train_images
self.num_reg_images += dataset.num_reg_images
def add_replacement(self, str_from, str_to):
for dataset in self.datasets:
dataset.add_replacement(str_from, str_to)
def make_buckets(self):
for dataset in self.datasets:
dataset.make_buckets()
def cache_latents(self, vae):
for dataset in self.datasets:
dataset.cache_latents(vae)
def is_latent_cachable(self) -> bool:
return all([dataset.is_latent_cachable() for dataset in self.datasets])
def set_current_epoch(self, epoch):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
@@ -1489,7 +1602,7 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子スペルミスを残してあります")
parser.add_argument("--keep_tokens", type=int, default=None,
parser.add_argument("--keep_tokens", type=int, default=0,
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
@@ -1515,11 +1628,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
parser.add_argument("--caption_dropout_rate", type=float, default=0,
parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
if support_dreambooth:
@@ -1787,10 +1900,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
args.caption_extension = args.caption_extention
args.caption_extention = None
if args.cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
# assert args.resolution is not None, f"resolution is required / resolution解像度を指定してください"
if args.resolution is not None:
args.resolution = tuple([int(r) for r in args.resolution.split(',')])