From 94c151aea33917ed1a0e938b9f571f83c1e17165 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 15 Jul 2023 18:28:33 +0900 Subject: [PATCH] refactor caching latents (flip in same npz, etc) --- finetune/prepare_buckets_latents.py | 114 +--------- library/train_util.py | 341 +++++++++++++++++----------- tools/cache_latents.py | 193 ++++++++++++++++ 3 files changed, 409 insertions(+), 239 deletions(-) create mode 100644 tools/cache_latents.py diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 6bb1c32f..1dde2294 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -34,22 +34,7 @@ def collate_fn_remove_corrupted(batch): return batch -def get_latents(vae, key_and_images, weight_dtype): - img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images] - img_tensors = torch.stack(img_tensors) - img_tensors = img_tensors.to(DEVICE, weight_dtype) - with torch.no_grad(): - latents = vae.encode(img_tensors).latent_dist.sample() - - # check NaN - for (key, _), latents1 in zip(key_and_images, latents): - if torch.isnan(latents1).any(): - raise ValueError(f"NaN detected in latents of {key}") - - return latents - - -def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): +def get_npz_filename(data_dir, image_key, is_full_path, recursive): if is_full_path: base_name = os.path.splitext(os.path.basename(image_key))[0] relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) @@ -57,13 +42,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive): base_name = image_key relative_path = "" - if flip: - base_name += "_flip" - if recursive and relative_path: - return os.path.join(data_dir, relative_path, base_name) + return os.path.join(data_dir, relative_path, base_name) + ".npz" else: - return os.path.join(data_dir, base_name) + return os.path.join(data_dir, base_name) + ".npz" def main(args): @@ -113,36 +95,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype) - assert ( - latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8 - ), f"latent shape {latents.shape}, {bucket[0][1].shape}" - - for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) - train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top) - - # flip - if args.flip_aug: - latents = get_latents( - vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype - ) # copyがないとTensor変換できない - - for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents): - npz_file_name = get_npz_filename_wo_ext( - args.train_data_dir, image_key, args.full_path, True, args.recursive - ) - train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top) - else: - # remove existing flipped npz - for image_key, _ in bucket: - npz_file_name = ( - get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" - ) - if os.path.isfile(npz_file_name): - print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") - os.remove(npz_file_name) - + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -203,61 +156,18 @@ def main(args): ), f"internal error resized size is small: {resized_size}, {reso}" # 既に存在するファイルがあればshape等を確認して同じならskipする + npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive) if args.skip_existing: - npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] - if args.flip_aug: - npz_files.append( - get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz" - ) - - found = True - for npz_file in npz_files: - if not os.path.exists(npz_file): - found = False - break - - latents, _, _ = train_util.load_latents_from_disk(npz_file) - if latents is None: # old version - found = False - break - - if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8: # latentsのshapeを確認 - found = False - break - if found: + if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug): continue - # 画像をリサイズしてトリミングする - # PILにinter_areaがないのでcv2で…… - image = np.array(image) - if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) - - trim_left = 0 - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size // 2 : trim_size // 2 + reso[0]] - trim_left = trim_size // 2 - - trim_top = 0 - if resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size // 2 : trim_size // 2 + reso[1]] - trim_top = trim_size // 2 - - original_size_wh = (resized_size[0], resized_size[1]) - # target_size_wh = (reso[0], reso[1]) - crop_left_top = (trim_left, trim_top) - - assert ( - image.shape[0] == reso[1] and image.shape[1] == reso[0] - ), f"internal error, illegal trimmed size: {image.shape}, {reso}" - - # # debug - # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) - # バッチへ追加 - bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top)) + image_info = train_util.ImageInfo(image_key, 1, "", False, image_path) + image_info.latents_npz = npz_file_name + image_info.bucket_reso = reso + image_info.resized_size = resized_size + image_info.image = image + bucket_manager.add_image(reso, image_info) # バッチを推論するか判定して推論する process_batch(False) diff --git a/library/train_util.py b/library/train_util.py index 15c58cc8..bf333299 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -50,6 +50,7 @@ from diffusers import ( HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, + AutoencoderKL, ) from library import custom_train_functions from library.original_unet import UNet2DConditionModel @@ -96,6 +97,13 @@ try: except: pass +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -110,10 +118,10 @@ class ImageInfo: self.latents: torch.Tensor = None self.latents_flipped: torch.Tensor = None self.latents_npz: str = None - self.latents_npz_flipped: str = None self.latents_original_size: Tuple[int, int] = None # original image size, not latents size self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top self.cond_img_path: str = None + self.image: Optional[Image.Image] = None # optional, original PIL Image class BucketManager: @@ -507,21 +515,22 @@ class BaseDataset(torch.utils.data.Dataset): # augmentation self.aug_helper = AugHelper() - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + self.image_transforms = IMAGE_TRANSFORMS self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.replacements = {} + # caching + self.caching_mode = None # None, 'latents', 'text' + def set_seed(self, seed): self.seed = seed + def set_caching_mode(self, mode): + self.caching_mode = mode + def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする self.shuffle_buckets() @@ -767,45 +776,6 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top) - def trim_and_resize_if_required( - self, subset: BaseSubset, image, reso, resized_size - ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: - image_height, image_width = image.shape[0:2] - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - original_size = (image_width, image_height) - - crop_left_top = (0, 0) - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) - image = image[:, p : p + reso[0]] - crop_left_top = (p, 0) - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) - image = image[p : p + reso[1]] - crop_left_top = (crop_left_top[0], p) - - assert ( - image.shape[0] == reso[1] and image.shape[1] == reso[0] - ), f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image, original_size, crop_left_top - def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) @@ -822,26 +792,6 @@ class BaseDataset(torch.utils.data.Dataset): ] ) - def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path): - expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 - - for npath in [npz_path, flipped_npz_path]: - if npath is None: - continue - if not os.path.exists(npath): - return False - - npz = np.load(npath) - if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver? - return False - - cached_latents = npz["latents"] - - if cached_latents.shape[1:3] != expected_latents_size: - return False - - return True - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): # ちょっと速くした print("caching latents.") @@ -864,13 +814,10 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if cache_to_disk: info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" - info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" - if not is_main_process: + if not is_main_process: # store to info only continue - cache_available = self.is_disk_cached_latents_is_expected( - info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None - ) + cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) if cache_available: # do not add to batch continue @@ -890,60 +837,19 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) > 0: batches.append(batch) - if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only + if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return - # iterate batches + # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded for batch in tqdm(batches, smoothing=1, total=len(batches)): - images = [] - for info in batch: - image = self.load_image(info.absolute_path) - image, original_size, crop_left_top = self.trim_and_resize_if_required( - subset, image, info.bucket_reso, info.resized_size - ) - image = self.image_transforms(image) - images.append(image) - - info.latents_original_size = original_size - info.latents_crop_left_top = crop_left_top - - img_tensors = torch.stack(images, dim=0) - img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) - - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - - for info, latent in zip(batch, latents): - # check NaN - if torch.isnan(latents).any(): - raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") - - if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top) - else: - info.latents = latent - - if subset.flip_aug: - img_tensors = torch.flip(img_tensors, dims=[3]) - latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - for info, latent in zip(batch, latents): - # check NaN - if torch.isnan(latents).any(): - raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") - - if cache_to_disk: - # crop_left_top is reversed when making batch - save_latents_to_disk( - info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top - ) - else: - info.latents_flipped = latent + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) def get_image_size(self, image_path): image = Image.open(image_path) return image.size def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = self.load_image(image_path) + img = load_image(image_path) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1004,10 +910,6 @@ class BaseDataset(torch.utils.data.Dataset): return image - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - return load_latents_from_disk(npz_file) - def __len__(self): return self._length @@ -1016,6 +918,9 @@ class BaseDataset(torch.utils.data.Dataset): bucket_batch_size = self.buckets_indices[index].bucket_batch_size image_index = self.buckets_indices[index].batch_index * bucket_batch_size + if self.caching_mode is not None: # return batch for latents/text encoder outputs caching + return self.get_item_for_caching(bucket, bucket_batch_size, image_index) + loss_weights = [] captions = [] input_ids_list = [] @@ -1045,7 +950,10 @@ class BaseDataset(torch.utils.data.Dataset): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped) + latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz) + if flipped: + latents = flipped_latents + del flipped_latents latents = torch.FloatTensor(latents) image = None @@ -1055,8 +963,8 @@ class BaseDataset(torch.utils.data.Dataset): im_h, im_w = img.shape[0:2] if self.enable_bucket: - img, original_size, crop_left_top = self.trim_and_resize_if_required( - subset, img, image_info.bucket_reso, image_info.resized_size + img, original_size, crop_left_top = trim_and_resize_if_required( + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size ) else: if face_cx > 0: # 顔位置情報あり @@ -1162,6 +1070,53 @@ class BaseDataset(torch.utils.data.Dataset): example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example + def get_item_for_caching(self, bucket, bucket_batch_size, image_index): + captions = [] + images = [] + absolute_paths = [] + resized_sizes = [] + bucket_reso = None + flip_aug = None + random_crop = None + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + + if flip_aug is None: + flip_aug = subset.flip_aug + random_crop = subset.random_crop + bucket_reso = image_info.bucket_reso + else: + assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert random_crop == subset.random_crop, "random_crop must be same in a batch" + assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" + + caption = image_info.caption # TODO cache some patterns of droping, shuffling, etc. + if self.caching_mode == "latents": + image = load_image(image_info.absolute_path) + else: + image = None + + captions.append(caption) + images.append(image) + absolute_paths.append(image_info.absolute_path) + resized_sizes.append(image_info.resized_size) + + example = {} + + if images[0] is None: + images = None + example["images"] = images + + example["captions"] = captions + example["absolute_paths"] = absolute_paths + example["resized_sizes"] = resized_sizes + example["flip_aug"] = flip_aug + example["random_crop"] = random_crop + example["bucket_reso"] = bucket_reso + return example + class DreamBoothDataset(BaseDataset): def __init__( @@ -1635,11 +1590,7 @@ class ControlNetDataset(BaseDataset): assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" - self.conditioning_image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - ] - ) + self.conditioning_image_transforms = IMAGE_TRANSFORMS def make_buckets(self): self.dreambooth_dataset_delegate.make_buckets() @@ -1667,7 +1618,7 @@ class ControlNetDataset(BaseDataset): original_size_hw = example["original_sizes_hw"][i] crop_top_left = example["crop_top_lefts"][i] flipped = example["flippeds"][i] - cond_img = self.load_image(image_info.cond_img_path) + cond_img = load_image(image_info.cond_img_path) if self.dreambooth_dataset_delegate.enable_bucket: cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ @@ -1729,6 +1680,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): print(f"[Dataset {i}]") dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def set_caching_mode(self, caching_mode): + for dataset in self.datasets: + dataset.set_caching_mode(caching_mode) + def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -1752,28 +1707,53 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.disable_token_padding() -# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) -def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]: - if npz_path is None: # flipped doesn't exist - return None, None, None +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): + expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 + if not os.path.exists(npz_path): + return False + + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver? + return False + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + return True + + +# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +def load_latents_from_disk( + npz_path, +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: npz = np.load(npz_path) if "latents" not in npz: print(f"error: npz is old format. please re-generate {npz_path}") - return None, None, None + return None, None, None, None latents = npz["latents"] original_size = npz["original_size"].tolist() crop_left_top = npz["crop_left_top"].tolist() - return latents, original_size, crop_left_top + flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None + return latents, original_size, crop_left_top, flipped_latents -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None): + kwargs = {} + if flipped_latents_tensor is not None: + kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), original_size=np.array(original_size), crop_left_top=np.array(crop_left_top), + **kwargs, ) @@ -1948,6 +1928,93 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group +def load_image(image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + +# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top) +def trim_and_resize_if_required( + random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] +) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]: + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + image_height, image_width = image.shape[0:2] + original_size = (image_width, image_height) + + crop_left_top = (0, 0) + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + crop_left_top = (p, 0) + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + crop_left_top = (crop_left_top[0], p) + + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image, original_size, crop_left_top + + +def cache_batch_latents( + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool +) -> None: + r""" + requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz + optionally requires image_infos to have: image + if cache_to_disk is True, set info.latents_npz + flipped latents is also saved if flip_aug is True + if cache_to_disk is False, set info.latents + latents_flipped is also set if flip_aug is True + latents_original_size and latents_crop_left_top are also set + """ + images = [] + for info in image_infos: + image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + image = IMAGE_TRANSFORMS(image) + images.append(image) + + info.latents_original_size = original_size + info.latents_crop_left_top = crop_left_top + + img_tensors = torch.stack(images, dim=0) + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + + if flip_aug: + img_tensors = torch.flip(img_tensors, dims=[3]) + with torch.no_grad(): + flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + # check NaN + if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): + raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") + + if cache_to_disk: + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent) + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + + # endregion # region モジュール入れ替え部 @@ -3975,7 +4042,7 @@ def sample_images_common( controlnet=controlnet, controlnet_image=controlnet_image, ) - + image = pipeline.latents_to_image(latents)[0] ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) diff --git a/tools/cache_latents.py b/tools/cache_latents.py new file mode 100644 index 00000000..d403d559 --- /dev/null +++ b/tools/cache_latents.py @@ -0,0 +1,193 @@ +# latentsのdiskへの事前キャッシュを行う / cache latents to disk + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import config_util +from library import train_util +from library import sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache latents arg + assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + if args.sdxl: + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizers = [tokenizer1, tokenizer2] + else: + tokenizer = train_util.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + # datasetのcache_latentsを呼ばなければ、生の画像が返る + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + print("load model") + if args.sdxl: + (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + else: + _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + vae.set_use_memory_efficient_attention_xformers(args.xformers) + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("latents") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + b_size = len(batch["images"]) + vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size + flip_aug = batch["flip_aug"] + random_crop = batch["random_crop"] + bucket_reso = batch["bucket_reso"] + + # バッチを分割して処理する + for i in range(0, b_size, vae_batch_size): + images = batch["images"][i : i + vae_batch_size] + absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] + resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] + + image_infos = [] + for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.image = image + image_info.bucket_reso = bucket_reso + image_info.resized_size = resized_size + image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" + + if args.skip_existing: + if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + print(f"Skipping {image_info.latents_npz} because it already exists.") + continue + + image_infos.append(image_info) + + if len(image_infos) > 0: + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args)