update dataset to return size, refactor ctrlnet ds

This commit is contained in:
Kohya S
2023-06-24 17:56:02 +09:00
parent f7f762c676
commit 9e9df2b501
3 changed files with 333 additions and 304 deletions


@@ -103,6 +103,9 @@ class ImageInfo:
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_npz_flipped: str = None
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
self.cond_img_path: str = None
class BucketManager:
@@ -171,6 +174,7 @@ class BucketManager:
def select_bucket(self, image_width, image_height):
aspect_ratio = image_width / image_height
if not self.no_upscale:
# perform both upscaling and downscaling
# the same aspect ratio may already exist, so when preprocessing for fine tuning was done with no_upscale=True, prefer a predefined resolution that matches exactly
reso = (image_width, image_height)
if reso in self.predefined_resos_set:
@@ -189,6 +193,7 @@ class BucketManager:
resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5))
# print("use predef", image_width, image_height, reso, resized_size)
else:
# perform downscaling only
if image_width * image_height > self.max_area:
# the image is too large, so choose the bucket assuming it will be downscaled while keeping the aspect ratio
resized_width = math.sqrt(self.max_area * aspect_ratio)
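For reference, a minimal sketch of the downscale-only arithmetic this branch relies on (the rounding to bucket steps that follows in the full file is omitted; max_area is an assumed value):

import math

def fit_to_max_area(width, height, max_area=1024 * 1024):
    # downscale only: keep the aspect ratio while fitting the pixel count under max_area
    if width * height <= max_area:
        return width, height
    ar = width / height
    new_w = math.sqrt(max_area * ar)  # solve w * h = max_area with w / h = ar
    return int(new_w + 0.5), int(new_w / ar + 0.5)

print(fit_to_max_area(1920, 1080))  # (1365, 768)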
@@ -238,41 +243,40 @@ class BucketBatchIndex(NamedTuple):
class AugHelper:
def __init__(self):
# prepare all possible augmentators
color_aug_method = albu.OneOf(
self.color_aug_method = albu.OneOf(
[
albu.HueSaturationValue(8, 0, 0, p=0.5),
albu.RandomGamma((95, 105), p=0.5),
],
p=0.33,
)
flip_aug_method = albu.HorizontalFlip(p=0.5)
# key: (use_color_aug, use_flip_aug)
self.augmentors = {
(True, True): albu.Compose(
[
color_aug_method,
flip_aug_method,
],
p=1.0,
),
(True, False): albu.Compose(
[
color_aug_method,
],
p=1.0,
),
(False, True): albu.Compose(
[
flip_aug_method,
],
p=1.0,
),
(False, False): None,
}
# self.augmentors = {
# (True, True): albu.Compose(
# [
# color_aug_method,
# flip_aug_method,
# ],
# p=1.0,
# ),
# (True, False): albu.Compose(
# [
# color_aug_method,
# ],
# p=1.0,
# ),
# (False, True): albu.Compose(
# [
# flip_aug_method,
# ],
# p=1.0,
# ),
# (False, False): None,
# }
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
return self.augmentors[(use_color_aug, use_flip_aug)]
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
return self.color_aug_method if use_color_aug else None
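After this refactor, AugHelper only builds the color augmentation; flipping is applied by the caller with a plain numpy slice. A minimal usage sketch (assuming albumentations and numpy are installed):

import random
import numpy as np

aug_helper = AugHelper()
aug = aug_helper.get_augmentor(use_color_aug=True)  # albu.OneOf or None

img = np.zeros((512, 512, 3), np.uint8)  # placeholder HWC image
if aug is not None:
    img = aug(image=img)["image"]
if random.random() < 0.5:  # the caller now decides the flip
    img = img[:, ::-1, :].copy()  # copy avoids negative strides downstream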
class BaseSubset:
@@ -454,10 +458,16 @@ class ControlNetSubset(BaseSubset):
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
self,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
debug_dataset: bool,
) -> None:
super().__init__()
self.tokenizer = tokenizer
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
@@ -478,7 +488,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_no_upscale = None
self.bucket_info = None # for metadata
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
self.current_epoch: int = 0  # instances seem to be recreated every epoch, so this must be passed in from the outside
@@ -594,48 +604,49 @@ class BaseDataset(torch.utils.data.Dataset):
return caption
def get_input_ids(self, caption):
input_ids = self.tokenizer(
def get_input_ids(self, caption, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizers[0]
input_ids = tokenizer(
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
).input_ids
if self.tokenizer_max_length > self.tokenizer.model_max_length:
if self.tokenizer_max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# when the length is 77 or more, the ids look like "<BOS> .... <EOS> <EOS> <EOS>" with a total such as 227, so convert them into triplets of "<BOS>...<EOS>"
# AUTOMATIC1111's implementation seems to split on "," and so on, but keep it simple here
for i in range(
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2
): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + self.tokenizer.model_max_length - 2],
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2
# v2 or SDXL
# when the length is 77 or more, the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." with a total such as 227, so convert them into triplets of "<BOS>...<EOS> <PAD> <PAD> ..."
for i in range(
1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2
):
for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + self.tokenizer.model_max_length - 2],
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
# if the chunk ends with x <PAD/EOS>, replace the last token with <EOS> (no effective change if it is already x <EOS>)
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
ids_chunk[-1] = self.tokenizer.eos_token_id
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
if ids_chunk[1] == self.tokenizer.pad_token_id:
ids_chunk[1] = self.tokenizer.eos_token_id
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
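A worked example of the window indices used above, assuming model_max_length == 77 and max_token_length == 225 (so tokenizer_max_length == 227):

model_max = 77
tok_max = 227
starts = list(range(1, tok_max - model_max + 2, model_max - 2))
print(starts)  # [1, 76, 151] -> three 75-token windows, each re-wrapped in BOS/EOS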
@@ -755,46 +766,58 @@ class BaseDataset(torch.utils.data.Dataset):
img = np.array(image, np.uint8)
return img
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img = None):
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]:
# resize
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # use cv2 because we want INTER_AREA
if exists(cond_img):
cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA)
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height)
crop_left_top = (0, 0)
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p : p + reso[0]]
if exists(cond_img):
cond_img = cond_img[:, p : p + reso[0]]
crop_left_top = (p, 0)
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p : p + reso[1]]
if exists(cond_img):
cond_img = cond_img[p : p + reso[1]]
crop_left_top = (crop_left_top[0], p)
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
if exists(cond_img):
assert (
cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
return image, cond_img
return image
return image, original_size, crop_left_top
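A worked example of the bookkeeping this method now returns (the numbers are hypothetical): a 1365x768 resized image center-cropped into a 1280x768 bucket.

resized_size = (1365, 768)  # (W, H) after cv2.resize
reso = (1280, 768)          # bucket (W, H)
trim_size = resized_size[0] - reso[0]  # 85 surplus pixels of width
p = trim_size // 2                     # 42 (random.randint(0, trim_size) with random_crop)
original_size = resized_size           # (1365, 768), recorded before cropping
crop_left_top = (p, 0)                 # (42, 0)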
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note: bucket_reso is WxH
for npath in [npz_path, flipped_npz_path]:
if npath is None:
continue
if not os.path.exists(npath):
return False
npz = np.load(npath)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
return False
cached_latents = npz["latents"]
if cached_latents.shape[1:3] != expected_latents_size:
return False
return True
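A minimal sketch of the new on-disk cache layout this check expects (the file name and shapes are hypothetical):

import numpy as np

np.savez(
    "image.npz",
    latents=np.zeros((4, 96, 160), np.float32),  # C, H/8, W/8 for a 1280x768 bucket
    original_size=np.array([1365, 768]),         # W, H of the resized image
    crop_left_top=np.array([42, 0]),             # crop offset inside the resized image
)

npz = np.load("image.npz")
print("latents" in npz and "original_size" in npz and "crop_left_top" in npz)  # True
print(npz["latents"].shape[1:3])  # (96, 160) == (reso[1] // 8, reso[0] // 8)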
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# made this a bit faster
print("caching latents.")
@@ -811,38 +834,26 @@ class BaseDataset(torch.utils.data.Dataset):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False)
info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents)
# might be None, but that's ok because check is done in dataset
info.latents_flipped = self.load_latents_from_npz(info, True)
info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None
if info.latents_flipped is not None:
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
# check disk cache exists and size of latents
if cache_to_disk:
# TODO: refactor to unify with FineTuningDataset
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
if not is_main_process:
continue
cache_available = False
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8)  # note: bucket_reso is WxH
if os.path.exists(info.latents_npz):
cached_latents = np.load(info.latents_npz)["arr_0"]
if cached_latents.shape[1:3] == expected_latents_size:
cache_available = True
cache_available = self.is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
)
if subset.flip_aug:
cache_available = False
if os.path.exists(info.latents_npz_flipped):
cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
if cached_latents_flipped.shape[1:3] == expected_latents_size:
cache_available = True
if cache_available:
if cache_available: # do not add to batch
continue
# if last member of batch has different resolution, flush the batch
@@ -868,10 +879,15 @@ class BaseDataset(torch.utils.data.Dataset):
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image, original_size, crop_left_top = self.trim_and_resize_if_required(
subset, image, info.bucket_reso, info.resized_size
)
image = self.image_transforms(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_left_top = crop_left_top
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
@@ -879,7 +895,12 @@ class BaseDataset(torch.utils.data.Dataset):
for info, latent in zip(batch, latents):
if cache_to_disk:
np.savez(info.latents_npz, latent.float().numpy())
np.savez(
info.latents_npz,
latents=latent.float().numpy(),
original_size=np.array(info.latents_original_size),
crop_left_top=np.array(info.latents_crop_left_top),
)
else:
info.latents = latent
@@ -888,7 +909,12 @@ class BaseDataset(torch.utils.data.Dataset):
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
if cache_to_disk:
np.savez(info.latents_npz_flipped, latent.float().numpy())
np.savez(
info.latents_npz_flipped,
latents=latent.float().numpy(),
original_size=np.array(info.latents_original_size),
crop_left_top=np.array(info.latents_crop_left_top),  # reversed horizontally when the flipped latents are used
)
else:
info.latents_flipped = latent
@@ -961,8 +987,13 @@ class BaseDataset(torch.utils.data.Dataset):
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
if npz_file is None:
return None
return np.load(npz_file)["arr_0"]
return None, None, None
npz = np.load(npz_file)
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_left_top = npz["crop_left_top"].tolist()
return latents, original_size, crop_left_top
def __len__(self):
return self._length
@@ -975,21 +1006,35 @@ class BaseDataset(torch.utils.data.Dataset):
loss_weights = []
captions = []
input_ids_list = []
input_ids2_list = []
latents_list = []
images = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
flippeds = []  # the variable name is a bit awkward
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
# process image/latents
if image_info.latents is not None:  # when cache_latents=True
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
original_size = image_info.latents_original_size
crop_left_top = image_info.latents_crop_left_top  # recalculated below if flipped
if not flipped:
latents = image_info.latents
else:
latents = image_info.latents_flipped
image = None
elif image_info.latents_npz is not None:  # for FineTuningDataset or when cache_latents_to_disk=True
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped)
latents = torch.FloatTensor(latents)
image = None
else:
# load the image and crop it if necessary
@@ -997,7 +1042,9 @@ class BaseDataset(torch.utils.data.Dataset):
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
img, original_size, crop_left_top = self.trim_and_resize_if_required(
subset, img, image_info.bucket_reso, image_info.resized_size
)
else:
if face_cx > 0:  # face position info is available
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
@@ -1017,17 +1064,33 @@ class BaseDataset(torch.utils.data.Dataset):
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
original_size = [im_w, im_h]
crop_left_top = [0, 0]
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
aug = self.aug_helper.get_augmentor(subset.color_aug)
if aug is not None:
img = aug(image=img)["image"]
if flipped:
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
latents = None
image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
images.append(image)
latents_list.append(latents)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
if flipped:
crop_left_top = (original_size[0] - crop_left_top[0] - target_size[0], crop_left_top[1])
original_sizes_hw.append((original_size[1], original_size[0]))
crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
target_sizes_hw.append((target_size[1], target_size[0]))
flippeds.append(flipped)
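# worked example of the flip correction above (values are hypothetical):
# with original_size = (1365, 768), crop_left_top = (42, 0) and target_size = (1280, 768),
# flipping mirrors the crop window horizontally, so the new left edge is
# 1365 - 42 - 1280 = 43 and crop_left_top becomes (43, 0)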
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
@@ -1039,22 +1102,33 @@ class BaseDataset(torch.utils.data.Dataset):
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
else:
token_caption = self.get_input_ids(caption)
token_caption = self.get_input_ids(caption, self.tokenizers[0])
input_ids_list.append(token_caption)
if len(self.tokenizers) > 1:
if self.XTI_layers:
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
else:
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2_list.append(token_caption2)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
if len(self.tokenizers) > 1:
# the following may not work in SDXL; keep the line for a future update
example["input_ids2"] = self.tokenizer[1](captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example["input_ids"] = torch.stack(input_ids_list)
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
if images[0] is not None:
images = torch.stack(images)
@@ -1066,6 +1140,11 @@ class BaseDataset(torch.utils.data.Dataset):
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts])
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
example["flippeds"] = flippeds
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
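The batch now carries size metadata for SDXL micro-conditioning. A minimal sketch of consuming it in a training loop (the loop and dataloader are assumed; the concatenation order should correspond to SDXL's add_time_ids, so treat this as illustrative):

import torch

for batch in train_dataloader:
    orig_hw = batch["original_sizes_hw"]  # LongTensor [B, 2] as (H, W)
    crop_tl = batch["crop_top_lefts"]     # LongTensor [B, 2] as (top, left)
    target_hw = batch["target_sizes_hw"]  # LongTensor [B, 2] as (H, W)
    size_cond = torch.cat([orig_hw, crop_tl, target_hw], dim=1)  # [B, 6]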
@@ -1462,151 +1541,86 @@ class ControlNetDataset(BaseDataset):
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset) -> None:
debug_dataset,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.conditioning_image_data: Dict[str, ImageInfo] = {}
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
db_subsets = []
for subset in subsets:
db_subset = DreamBoothSubset(
subset.image_dir,
False,
None,
subset.caption_extension,
subset.num_repeats,
subset.shuffle_caption,
subset.keep_tokens,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
subset.random_crop,
subset.caption_dropout_rate,
subset.caption_dropout_every_n_epochs,
subset.caption_tag_dropout_rate,
subset.token_warmup_min,
subset.token_warmup_step,
)
db_subsets.append(db_subset)
self.dreambooth_dataset_delegate = DreamBoothDataset(
db_subsets,
batch_size,
tokenizer,
max_token_length,
resolution,
enable_bucket,
min_bucket_reso,
max_bucket_reso,
bucket_reso_steps,
bucket_no_upscale,
1.0,
debug_dataset,
)
# store values referenced from config_util etc. (a bit awkward; should be cleaned up)
self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size
self.size = min(self.width, self.height)  # the shorter side
self.latents_cache = None
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.num_reg_images = 0
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None  # this information is not used
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
# build candidate caption file names
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
# assert all conditioning data exists
missing_imgs = []
cond_imgs_with_img = set()
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
subset = None
for s in subsets:
if s.image_dir == db_subset.image_dir:
subset = s
break
return caption
assert subset is not None, "internal error: subset not found"
def load_controlnet_dir(subset: ControlNetSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
if not os.path.isdir(subset.conditioning_data_dir):
print(f"not directory: {subset.conditioning_data_dir}")
return [], []
continue
img_paths = glob_images(subset.image_dir, "*")
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
img_paths = sorted(img_paths)
conditioning_img_paths = sorted(conditioning_img_paths)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
img_basename = os.path.basename(info.absolute_path)
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
if not os.path.exists(ctrl_img_path):
missing_imgs.append(img_basename)
img_basenames = [os.path.basename(img) for img in img_paths]
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
missing_imgs = []
extra_imgs = []
info.cond_img_path = ctrl_img_path
cond_imgs_with_img.add(ctrl_img_path)
for img in img_basenames:
if img not in conditioning_img_basenames:
missing_imgs.append(img)
for img in conditioning_img_basenames:
if img not in img_basenames:
extra_imgs.append(img)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
# read the caption for each image file and use it if present
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # record tag frequencies
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption + f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, conditioning_img_paths, captions
print("prepare images.")
num_train_images = 0
extra_imgs = []
for subset in subsets:
if subset.num_repeats < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
)
continue
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
extra_imgs.extend(
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
)
if subset in self.subsets:
print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
)
continue
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
if len(img_paths) < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
continue
num_train_images += subset.num_repeats * len(img_paths)
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
setattr(info, "cond_img_path", cond_img_path)
self.register_image(info, subset)
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
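The refactored dataset pairs each training image with a conditioning image of the same basename, as the asserts above enforce. A sketch of the expected layout (paths are hypothetical):

import os

# train/        -> 0001.png, 0002.png, ...  (subset.image_dir)
# conditioning/ -> 0001.png, 0002.png, ...  (subset.conditioning_data_dir)
img_path = "train/0001.png"
cond_img_path = os.path.join("conditioning", os.path.basename(img_path))
print(cond_img_path)  # conditioning/0001.png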
self.conditioning_image_transforms = transforms.Compose(
[
@@ -1614,88 +1628,58 @@ class ControlNetDataset(BaseDataset):
]
)
def __getitem__(self, index):
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
def __len__(self):
return self.dreambooth_dataset_delegate.__len__()
def __getitem__(self, index):
example = self.dreambooth_dataset_delegate[index]
bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[
self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index
]
bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size
image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size
loss_weights = []
captions = []
input_ids_list = []
latents_list = []
images = []
conditioning_images = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(1.0)
for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]):
image_info = self.dreambooth_dataset_delegate.image_data[image_key]
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
target_size_hw = example["target_sizes_hw"][i]
original_size_hw = example["original_sizes_hw"][i]
crop_top_left = example["crop_top_lefts"][i]
flipped = example["flippeds"][i]
cond_img = self.load_image(image_info.cond_img_path)
# process image/latents
if image_info.latents is not None:  # when cache_latents=True
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None:  # for FineTuningDataset or when cache_latents_to_disk=True
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
if self.dreambooth_dataset_delegate.enable_bucket:
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)  # use cv2 because we want INTER_AREA
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
ct, cl = crop_top_left
h, w = target_size_hw
cond_img = cond_img[ct : ct + h, cl : cl + w]
else:
# load the image and crop it if necessary
img = self.load_image(image_info.absolute_path)
cond_img = self.load_image(image_info.cond_img_path)
im_h, im_w = img.shape[0:2]
assert (
cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
if self.enable_bucket:
img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
else:
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
if aug is not None:
img = aug(image=img)["image"]
latents = None
image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
images.append(image)
latents_list.append(latents)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
caption = self.process_caption(subset, image_info.caption)
captions.append(caption)
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
example["input_ids"] = torch.stack(input_ids_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
return example
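The rewritten __getitem__ delegates to the DreamBoothDataset and then replays the recorded geometry on the conditioning image so both stay pixel-aligned. A condensed sketch of that pattern (names are from this diff; the conditioning image is assumed to be loaded and resized already):

example = self.dreambooth_dataset_delegate[index]
ct, cl = example["crop_top_lefts"][0]
h, w = example["target_sizes_hw"][0]
cond_img = cond_img[ct : ct + h, cl : cl + w]  # identical crop window
if example["flippeds"][0]:
    cond_img = cond_img[:, ::-1, :].copy()     # identical flip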
# behaves as a mock Dataset
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
@@ -1773,18 +1757,42 @@ def debug_dataset(train_dataset, show_input_ids=False):
example = train_dataset[idx]
if example["latents"] is not None:
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
zip(
example["image_keys"],
example["captions"],
example["loss_weights"],
example["input_ids"],
example["original_sizes_hw"],
example["crop_top_lefts"],
example["target_sizes_hw"],
example["flippeds"],
)
):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop left top: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if show_input_ids:
print(f"input ids: {iid}")
if "input_ids2" in example:
print(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if "conditioning_images" in example:
cond_img = example["conditioning_images"][j]
print(f"conditioning image size: {cond_img.size()}")
cond_img = (cond_img.numpy() * 255.0).astype(np.uint8)
cond_img = np.transpose(cond_img, (1, 2, 0))
cond_img = cond_img[:, :, ::-1]
if os.name == "nt":
cv2.imshow("cond_img", cond_img)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
@@ -2011,7 +2019,6 @@ def get_git_revision_hash() -> str:
return "(unknown)"
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
# replace_attentions_for_hypernetwork()
# # unet is not used currently, but it is here for future use
@@ -2063,8 +2070,9 @@ def get_git_revision_hash() -> str:
# out = self.to_out[1](out)
# return out
# diffusers.models.attention.CrossAttention.forward = forward_xformers
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
unet.set_use_memory_efficient_attention(False, True)
@@ -2080,6 +2088,7 @@ def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa
print("Enable SDPA for U-Net")
unet.set_use_sdpa(True)
"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
# vae is not used currently, but it is here for future use
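A typical call site for the reformatted signature, assuming args comes from add_training_arguments below:

replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)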
@@ -2327,7 +2336,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使うPyTorch 2.0が必要)")
parser.add_argument(
"--sdpa",
action="store_true",
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使うPyTorch 2.0が必要)",
)
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
)
@@ -3231,7 +3244,9 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
)
else:
# Diffusers model is loaded to CPU
print(f"load Diffusers pretrained models: {name_or_path}")
@@ -3281,7 +3296,10 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
args,
weight_dtype,
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
@@ -3595,7 +3613,17 @@ SCHEDLER_SCHEDULE = "scaled_linear"
def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
accelerator,
args: argparse.Namespace,
epoch,
steps,
device,
vae,
tokenizer,
text_encoder,
unet,
prompt_replacement=None,
controlnet=None,
):
"""
Since a modified version of StableDiffusionLongPromptWeightingPipeline is used, clip skip and prompt weighting are supported
@@ -3690,7 +3718,7 @@ def sample_images(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.clip_skip = args.clip_skip  # clip_skip cannot be passed to the pipeline constructor, so set it afterwards
pipeline.to(device)
save_dir = args.output_dir + "/sample"
@@ -3765,7 +3793,6 @@ def sample_images(
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)