mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add detail dataset config feature by extra config file (#227)
* add config file schema * change config file specification * refactor config utility * unify batch_size to train_batch_size * fix indent size * use batch_size instead of train_batch_size * make cache_latents configurable on subset * rename options * bucket_repo_range * shuffle_keep_tokens * update readme * revert to min_bucket_reso & max_bucket_reso * use subset structure in dataset * format import lines * split mode specific options * use only valid subset * change valid subsets name * manage multiple datasets by dataset group * update config file sanitizer * prune redundant validation * add comments * update type annotation * rename json_file_name to metadata_file * ignore when image dir is invalid * fix tag shuffle and dropout * ignore duplicated subset * add method to check latent cachability * fix format * fix bug * update caption dropout default values * update annotation * fix bug * add option to enable bucket shuffle across dataset * update blueprint generate function * use blueprint generator for dataset initialization * delete duplicated function * update config readme * delete debug print * print dataset and subset info as info * enable bucket_shuffle_across_dataset option * update config readme for clarification * compensate quotes for string option example * fix bug of bad usage of join * conserve trained metadata backward compatibility * enable shuffle in data loader by default * delete resolved TODO * add comment for image data handling * fix reference bug * fix undefined variable bug * prevent raise overwriting * assert image_dir and metadata_file validity * add debug message for ignoring subset * fix inconsistent import statement * loosen too strict validation on float value * sanitize argument parser separately * make image_dir optional for fine tuning dataset * fix import * fix trailing characters in print * parse flexible dataset config deterministically * use relative import * print supplementary message for parsing error * add note about different methods * add note of benefit of separate dataset * add error example * add note for english readme plan --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
@@ -6,8 +6,15 @@ import json
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
from typing import Optional, Union
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator
|
||||
import glob
|
||||
import math
|
||||
@@ -203,23 +210,93 @@ class BucketBatchIndex(NamedTuple):
|
||||
batch_index: int
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer: CLIPTokenizer = tokenizer
|
||||
self.max_token_length = max_token_length
|
||||
class AugHelper:
|
||||
def __init__(self):
|
||||
# prepare all possible augmentators
|
||||
color_aug_method = albu.OneOf([
|
||||
albu.HueSaturationValue(8, 0, 0, p=.5),
|
||||
albu.RandomGamma((95, 105), p=.5),
|
||||
], p=.33)
|
||||
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
||||
|
||||
# key: (use_color_aug, use_flip_aug)
|
||||
self.augmentors = {
|
||||
(True, True): albu.Compose([
|
||||
color_aug_method,
|
||||
flip_aug_method,
|
||||
], p=1.),
|
||||
(True, False): albu.Compose([
|
||||
color_aug_method,
|
||||
], p=1.),
|
||||
(False, True): albu.Compose([
|
||||
flip_aug_method,
|
||||
], p=1.),
|
||||
(False, False): None
|
||||
}
|
||||
|
||||
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
||||
return self.augmentors[(use_color_aug, use_flip_aug)]
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.num_repeats = num_repeats
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.shuffle_keep_tokens = shuffle_keep_tokens
|
||||
self.keep_tokens = keep_tokens
|
||||
self.color_aug = color_aug
|
||||
self.flip_aug = flip_aug
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
self.random_crop = random_crop
|
||||
self.caption_dropout_rate = caption_dropout_rate
|
||||
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
|
||||
self.caption_tag_dropout_rate = caption_tag_dropout_rate
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
||||
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
||||
|
||||
self.is_reg = is_reg
|
||||
self.class_tokens = class_tokens
|
||||
self.caption_extension = caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, DreamBoothSubset):
|
||||
return NotImplemented
|
||||
return self.image_dir == other.image_dir
|
||||
|
||||
class FineTuningSubset(BaseSubset):
|
||||
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
||||
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, FineTuningSubset):
|
||||
return NotImplemented
|
||||
return self.metadata_file == other.metadata_file
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
self.flip_aug = flip_aug
|
||||
self.color_aug = color_aug
|
||||
self.debug_dataset = debug_dataset
|
||||
self.random_crop = random_crop
|
||||
|
||||
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
||||
|
||||
self.token_padding_disabled = False
|
||||
self.dataset_dirs_info = {}
|
||||
self.reg_dataset_dirs_info = {}
|
||||
self.tag_frequency = {}
|
||||
|
||||
self.enable_bucket = False
|
||||
@@ -233,42 +310,20 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
self.tag_dropout_rate: float = 0
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
if color_aug:
|
||||
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
||||
self.aug = albu.Compose([
|
||||
albu.OneOf([
|
||||
albu.HueSaturationValue(8, 0, 0, p=.5),
|
||||
albu.RandomGamma((95, 105), p=.5),
|
||||
], p=.33),
|
||||
albu.HorizontalFlip(p=flip_p)
|
||||
], p=1.)
|
||||
elif flip_aug:
|
||||
self.aug = albu.Compose([
|
||||
albu.HorizontalFlip(p=flip_p)
|
||||
], p=1.)
|
||||
else:
|
||||
self.aug = None
|
||||
self.aug_helper = AugHelper()
|
||||
|
||||
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
||||
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
||||
self.dropout_rate = dropout_rate
|
||||
self.dropout_every_n_epochs = dropout_every_n_epochs
|
||||
self.tag_dropout_rate = tag_dropout_rate
|
||||
self.shuffle_buckets()
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
@@ -286,42 +341,36 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def add_replacement(self, str_from, str_to):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
|
||||
def process_caption(self, subset: BaseSubset, caption):
|
||||
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
||||
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
||||
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
||||
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
||||
is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
||||
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
||||
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
|
||||
def dropout_tags(tokens):
|
||||
if self.tag_dropout_rate <= 0:
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
l = []
|
||||
for token in tokens:
|
||||
if random.random() >= self.tag_dropout_rate:
|
||||
if random.random() >= subset.caption_tag_dropout_rate:
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
fixed_tokens = []
|
||||
flex_tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if subset.keep_tokens >= 0:
|
||||
fixed_tokens = flex_tokens[:subset.keep_tokens]
|
||||
flex_tokens = flex_tokens[subset.keep_tokens:]
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ", ".join(tokens)
|
||||
caption = ", ".join(fixed_tokens + flex_tokens)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
@@ -375,8 +424,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
return input_ids
|
||||
|
||||
def register_image(self, info: ImageInfo):
|
||||
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
||||
self.image_data[info.image_key] = info
|
||||
self.image_to_subset[info.image_key] = subset
|
||||
|
||||
def make_buckets(self):
|
||||
'''
|
||||
@@ -475,7 +525,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
def trim_and_resize_if_required(self, image, reso, resized_size):
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
@@ -485,22 +535,27 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
image_height, image_width = image.shape[0:2]
|
||||
if image_width > reso[0]:
|
||||
trim_size = image_width - reso[0]
|
||||
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("w", trim_size, p)
|
||||
image = image[:, p:p + reso[0]]
|
||||
if image_height > reso[1]:
|
||||
trim_size = image_height - reso[1]
|
||||
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
||||
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
|
||||
# print("h", trim_size, p)
|
||||
image = image[p:p + reso[1]]
|
||||
|
||||
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
return image
|
||||
|
||||
def is_latent_cachable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def cache_latents(self, vae):
|
||||
# TODO ここを高速化したい
|
||||
print("caching latents.")
|
||||
for info in tqdm(self.image_data.values()):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
info.latents = self.load_latents_from_npz(info, False)
|
||||
info.latents = torch.FloatTensor(info.latents)
|
||||
@@ -510,13 +565,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||
|
||||
if self.flip_aug:
|
||||
if subset.flip_aug:
|
||||
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
@@ -526,11 +581,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
image = Image.open(image_path)
|
||||
return image.size
|
||||
|
||||
def load_image_with_face_info(self, image_path: str):
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
||||
img = self.load_image(image_path)
|
||||
|
||||
face_cx = face_cy = face_w = face_h = 0
|
||||
if self.face_crop_aug_range is not None:
|
||||
if subset.face_crop_aug_range is not None:
|
||||
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
||||
if len(tokens) >= 5:
|
||||
face_cx = int(tokens[-4])
|
||||
@@ -541,7 +596,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return img, face_cx, face_cy, face_w, face_h
|
||||
|
||||
# いい感じに切り出す
|
||||
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
||||
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
|
||||
height, width = image.shape[0:2]
|
||||
if height == self.height and width == self.width:
|
||||
return image
|
||||
@@ -549,8 +604,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# 画像サイズはsizeより大きいのでリサイズする
|
||||
face_size = max(face_w, face_h)
|
||||
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
||||
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
||||
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
||||
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
||||
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
||||
if min_scale >= max_scale: # range指定がmin==max
|
||||
scale = min_scale
|
||||
else:
|
||||
@@ -568,13 +623,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
||||
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
||||
|
||||
if self.random_crop:
|
||||
if subset.random_crop:
|
||||
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
||||
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
||||
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
||||
else:
|
||||
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
||||
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
||||
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
|
||||
if face_size > self.size // 10 and face_size >= 40:
|
||||
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
||||
|
||||
@@ -597,9 +652,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index == 0:
|
||||
self.shuffle_buckets()
|
||||
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
@@ -612,28 +664,29 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None:
|
||||
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None:
|
||||
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
||||
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
||||
elif im_h > self.height or im_w > self.width:
|
||||
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
||||
assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
||||
if im_h > self.height:
|
||||
p = random.randint(0, im_h - self.height)
|
||||
img = img[p:p + self.height]
|
||||
@@ -645,8 +698,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
# augmentation
|
||||
if self.aug is not None:
|
||||
img = self.aug(image=img)['image']
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)['image']
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
@@ -654,7 +708,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(image_info.caption)
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
@@ -685,9 +739,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||
def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -710,7 +763,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path):
|
||||
def read_caption(img_path, caption_extension):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
@@ -733,153 +786,170 @@ class DreamBoothDataset(BaseDataset):
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_dreambooth_dir(dir):
|
||||
if not os.path.isdir(dir):
|
||||
# print(f"ignore file: {dir}")
|
||||
return 0, [], []
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
print(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
|
||||
tokens = os.path.basename(dir).split('_')
|
||||
try:
|
||||
n_repeats = int(tokens[0])
|
||||
except ValueError as e:
|
||||
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
||||
return 0, [], []
|
||||
|
||||
caption_by_folder = '_'.join(tokens[1:])
|
||||
img_paths = glob_images(dir, "*")
|
||||
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path)
|
||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
||||
captions.append("")
|
||||
else:
|
||||
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
||||
return img_paths, captions
|
||||
|
||||
return n_repeats, img_paths, captions
|
||||
|
||||
print("prepare train images.")
|
||||
train_dirs = os.listdir(train_data_dir)
|
||||
print("prepare images.")
|
||||
num_train_images = 0
|
||||
for dir in train_dirs:
|
||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
||||
num_train_images += n_repeats * len(img_paths)
|
||||
num_reg_images = 0
|
||||
reg_infos: List[ImageInfo] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
continue
|
||||
|
||||
img_paths, captions = load_dreambooth_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
if subset.is_reg:
|
||||
num_reg_images += subset.num_repeats * len(img_paths)
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||
self.register_image(info)
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if subset.is_reg:
|
||||
reg_infos.append(info)
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
# reg imageは数を数えて学習画像と同じ枚数にする
|
||||
num_reg_images = 0
|
||||
if reg_data_dir:
|
||||
print("prepare reg images.")
|
||||
reg_infos: List[ImageInfo] = []
|
||||
print(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
reg_dirs = os.listdir(reg_data_dir)
|
||||
for dir in reg_dirs:
|
||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
||||
num_reg_images += n_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||
reg_infos.append(info)
|
||||
|
||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
|
||||
print(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
if num_reg_images == 0:
|
||||
print("no regularization images / 正則化画像が見つかりませんでした")
|
||||
else:
|
||||
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
if num_reg_images == 0:
|
||||
print("no regularization images / 正則化画像が見つかりませんでした")
|
||||
else:
|
||||
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||
def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(json_file_name):
|
||||
print(f"loading existing metadata: {json_file_name}")
|
||||
with open(json_file_name, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
||||
|
||||
self.metadata = metadata
|
||||
self.train_data_dir = train_data_dir
|
||||
self.batch_size = batch_size
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(train_data_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||
abs_path = abs_path[0]
|
||||
|
||||
caption = img_md.get('caption')
|
||||
tags = img_md.get('tags')
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
tags_list.append(tags)
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||
|
||||
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get('train_resolution')
|
||||
|
||||
if not self.color_aug and not self.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
||||
|
||||
self.register_image(image_info)
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
continue
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
print(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
||||
|
||||
if len(metadata) < 1:
|
||||
print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(subset.image_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||
abs_path = abs_path[0]
|
||||
|
||||
caption = img_md.get('caption')
|
||||
tags = img_md.get('tags')
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
tags_list.append(tags)
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get('train_resolution')
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
|
||||
subset.img_count = len(metadata)
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = not (self.color_aug or self.random_crop)
|
||||
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_npz is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if self.flip_aug:
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
@@ -891,7 +961,7 @@ class FineTuningDataset(BaseDataset):
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
||||
if self.flip_aug:
|
||||
if flip_aug_in_subset:
|
||||
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
@@ -937,7 +1007,7 @@ class FineTuningDataset(BaseDataset):
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, image_key):
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + '.npz'
|
||||
|
||||
@@ -949,8 +1019,8 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
||||
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz')
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz')
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
@@ -961,6 +1031,49 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
# behave as Dataset mock
|
||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||
self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]]
|
||||
|
||||
super().__init__(datasets)
|
||||
|
||||
self.image_data = {}
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
|
||||
# simply concat together
|
||||
# TODO: handling image_data key duplication among dataset
|
||||
# In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset.
|
||||
for dataset in datasets:
|
||||
self.image_data.update(dataset.image_data)
|
||||
self.num_train_images += dataset.num_train_images
|
||||
self.num_reg_images += dataset.num_reg_images
|
||||
|
||||
def add_replacement(self, str_from, str_to):
|
||||
for dataset in self.datasets:
|
||||
dataset.add_replacement(str_from, str_to)
|
||||
|
||||
def make_buckets(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.make_buckets()
|
||||
|
||||
def cache_latents(self, vae):
|
||||
for dataset in self.datasets:
|
||||
dataset.cache_latents(vae)
|
||||
|
||||
def is_latent_cachable(self) -> bool:
|
||||
return all([dataset.is_latent_cachable() for dataset in self.datasets])
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
|
||||
def disable_token_padding(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.disable_token_padding()
|
||||
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
@@ -1489,7 +1602,7 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
||||
parser.add_argument("--caption_extention", type=str, default=None,
|
||||
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
||||
parser.add_argument("--keep_tokens", type=int, default=None,
|
||||
parser.add_argument("--keep_tokens", type=int, default=0,
|
||||
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
||||
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
||||
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
||||
@@ -1515,11 +1628,11 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
||||
parser.add_argument("--caption_dropout_rate", type=float, default=0.0,
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
||||
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0,
|
||||
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
||||
|
||||
if support_dreambooth:
|
||||
@@ -1787,10 +1900,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
args.caption_extension = args.caption_extention
|
||||
args.caption_extention = None
|
||||
|
||||
if args.cache_latents:
|
||||
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
||||
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
||||
|
||||
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
||||
if args.resolution is not None:
|
||||
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
||||
|
||||
Reference in New Issue
Block a user