From cc3d40ca44a78aa99493581ad4681493f2775e7b Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 7 Jul 2023 21:16:41 +0900
Subject: [PATCH] support sdxl in prepare script

---
 finetune/prepare_buckets_latents.py | 46 +++++++++++++-----
 library/train_util.py               | 75 +++++++++++++++++------------
 2 files changed, 77 insertions(+), 44 deletions(-)

diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index fd289d1d..6bb1c32f 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -34,12 +34,18 @@ def collate_fn_remove_corrupted(batch):
     return batch
 
 
-def get_latents(vae, images, weight_dtype):
-    img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
+def get_latents(vae, key_and_images, weight_dtype):
+    img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
     img_tensors = torch.stack(img_tensors)
     img_tensors = img_tensors.to(DEVICE, weight_dtype)
     with torch.no_grad():
-        latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
+        latents = vae.encode(img_tensors).latent_dist.sample()
+
+    # check NaN
+    for (key, _), latents1 in zip(key_and_images, latents):
+        if torch.isnan(latents1).any():
+            raise ValueError(f"NaN detected in latents of {key}")
+
     return latents
 
 
@@ -107,24 +113,26 @@ def main(args):
     def process_batch(is_last):
         for bucket in bucket_manager.buckets:
             if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
-                latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
+                latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype)
                 assert (
                     latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
                 ), f"latent shape {latents.shape}, {bucket[0][1].shape}"
 
-                for (image_key, _), latent in zip(bucket, latents):
+                for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
                     npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
-                    np.savez(npz_file_name, latent)
+                    train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
 
                 # flip
                 if args.flip_aug:
-                    latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype)  # copyがないとTensor変換できない
+                    latents = get_latents(
+                        vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
+                    )  # copyがないとTensor変換できない
 
-                    for (image_key, _), latent in zip(bucket, latents):
+                    for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
                         npz_file_name = get_npz_filename_wo_ext(
                             args.train_data_dir, image_key, args.full_path, True, args.recursive
                         )
-                        np.savez(npz_file_name, latent)
+                        train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
                 else:
                     # remove existing flipped npz
                     for image_key, _ in bucket:
@@ -194,7 +202,7 @@ def main(args):
                 resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
             ), f"internal error resized size is small: {resized_size}, {reso}"
 
-        # 既に存在するファイルがあればshapeを確認して同じならskipする
+        # 既に存在するファイルがあればshape等を確認して同じならskipする
         if args.skip_existing:
             npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
             if args.flip_aug:
@@ -208,8 +216,12 @@ def main(args):
                     found = False
                     break
 
-                dat = np.load(npz_file)["arr_0"]
-                if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:  # latentsのshapeを確認
+                latents, _, _ = train_util.load_latents_from_disk(npz_file)
+                if latents is None:  # old version
+                    found = False
+                    break
+
+                if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8:  # latentsのshapeを確認
                     found = False
                     break
             if found:
@@ -221,13 +233,21 @@ def main(args):
         if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:  # リサイズ処理が必要?
             image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
 
+        trim_left = 0
         if resized_size[0] > reso[0]:
             trim_size = resized_size[0] - reso[0]
             image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
+            trim_left = trim_size // 2
 
+        trim_top = 0
         if resized_size[1] > reso[1]:
             trim_size = resized_size[1] - reso[1]
             image = image[trim_size // 2 : trim_size // 2 + reso[1]]
+            trim_top = trim_size // 2
+
+        original_size_wh = (resized_size[0], resized_size[1])
+        # target_size_wh = (reso[0], reso[1])
+        crop_left_top = (trim_left, trim_top)
 
         assert (
             image.shape[0] == reso[1] and image.shape[1] == reso[0]
@@ -237,7 +257,7 @@ def main(args):
         # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
 
         # バッチへ追加
-        bucket_manager.add_image(reso, (image_key, image))
+        bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top))
 
         # バッチを推論するか判定して推論する
         process_batch(False)
diff --git a/library/train_util.py b/library/train_util.py
index 9ab1f538..62cd145e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -124,11 +124,11 @@ class BucketManager:
 
         self.resos = []
         self.reso_to_id = {}
-        self.buckets = []  # 前処理時は (image_key, image)、学習時は image_key
+        self.buckets = []  # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key
 
-    def add_image(self, reso, image):
+    def add_image(self, reso, image_or_info):
         bucket_id = self.reso_to_id[reso]
-        self.buckets[bucket_id].append(image)
+        self.buckets[bucket_id].append(image_or_info)
 
     def shuffle(self):
         for bucket in self.buckets:
@@ -767,7 +767,10 @@ class BaseDataset(torch.utils.data.Dataset):
         img = np.array(image, np.uint8)
         return img
 
-    def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
+    # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
+    def trim_and_resize_if_required(
+        self, subset: BaseSubset, image, reso, resized_size
+    ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
         image_height, image_width = image.shape[0:2]
 
         if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -907,19 +910,13 @@ class BaseDataset(torch.utils.data.Dataset):
 
             latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
 
-            # check NaN
-            for info, latents1 in zip(batch, latents):
-                if torch.isnan(latents1).any():
+            for info, latent in zip(batch, latents):
+                # check NaN
+                if torch.isnan(latents).any():
                     raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
-            for info, latent in zip(batch, latents):
                 if cache_to_disk:
-                    np.savez(
-                        info.latents_npz,
-                        latents=latent.float().numpy(),
-                        original_size=np.array(info.latents_original_size),
-                        crop_left_top=np.array(info.latents_crop_left_top),
-                    )
+                    save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
                 else:
                     info.latents = latent
 
@@ -927,12 +924,14 @@ class BaseDataset(torch.utils.data.Dataset):
                 img_tensors = torch.flip(img_tensors, dims=[3])
                 latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
                 for info, latent in zip(batch, latents):
+                    # check NaN
+                    if torch.isnan(latents).any():
+                        raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+
                     if cache_to_disk:
-                        np.savez(
-                            info.latents_npz_flipped,
-                            latents=latent.float().numpy(),
-                            original_size=np.array(info.latents_original_size),
-                            crop_left_top=np.array(info.latents_crop_left_top),  # reverse horizontally when use flipped latents
+                        # crop_left_top is reversed when making batch
+                        save_latents_to_disk(
+                            info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
                         )
                     else:
                         info.latents_flipped = latent
@@ -1005,18 +1004,7 @@ class BaseDataset(torch.utils.data.Dataset):
 
     def load_latents_from_npz(self, image_info: ImageInfo, flipped):
         npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
-        if npz_file is None:
-            return None, None, None
-
-        npz = np.load(npz_file)
-        if "latents" not in npz:
-            print(f"error: npz is old format. please re-generate {npz_file}")
-            return None, None, None
-
-        latents = npz["latents"]
-        original_size = npz["original_size"].tolist()
-        crop_left_top = npz["crop_left_top"].tolist()
-        return latents, original_size, crop_left_top
+        return load_latents_from_disk(npz_file)
 
     def __len__(self):
         return self._length
@@ -1762,6 +1750,31 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             dataset.disable_token_padding()
 
 
+# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
+def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]:
+    if npz_path is None:  # flipped doesn't exist
+        return None, None, None
+
+    npz = np.load(npz_path)
+    if "latents" not in npz:
+        print(f"error: npz is old format. please re-generate {npz_path}")
+        return None, None, None
+
+    latents = npz["latents"]
+    original_size = npz["original_size"].tolist()
+    crop_left_top = npz["crop_left_top"].tolist()
+    return latents, original_size, crop_left_top
+
+
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top):
+    np.savez(
+        npz_path,
+        latents=latents_tensor.float().cpu().numpy(),
+        original_size=np.array(original_size),
+        crop_left_top=np.array(crop_left_top),
+    )
+
+
 def debug_dataset(train_dataset, show_input_ids=False):
     print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
     print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
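Note on the new per-image metadata: in the prepare script, original_size_wh is simply the size the image was resized to, and crop_left_top is the offset of the centre trim that reduces that resized image to the bucket resolution. A minimal sketch of that bookkeeping, mirroring the trim logic added to main() above (the helper name and the example sizes are made up for illustration, not part of the patch):

def crop_metadata(resized_size, reso):
    # Mirrors the bookkeeping added to prepare_buckets_latents.py: the image has already been
    # resized to `resized_size` (width, height) and is then centre-trimmed down to `reso`.
    trim_left = (resized_size[0] - reso[0]) // 2 if resized_size[0] > reso[0] else 0
    trim_top = (resized_size[1] - reso[1]) // 2 if resized_size[1] > reso[1] else 0
    return (resized_size[0], resized_size[1]), (trim_left, trim_top)

# e.g. a 1152x896 resize going into a 1024x896 bucket is trimmed 64 px from the left:
print(crop_metadata((1152, 896), (1024, 896)))  # ((1152, 896), (64, 0))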
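Note on the cache format: the latents cache moves from a bare np.savez(npz_file_name, latent) (a single unnamed arr_0) to the keyed .npz written by the new save_latents_to_disk and read back by load_latents_from_disk; old-format files come back as None from the loader and are re-generated rather than skipped. A self-contained round-trip sketch of that format, outside the library (file name and shapes are illustrative; a 1024x896 bucket gives latents of height 896//8 and width 1024//8):

import numpy as np
import torch

latent = torch.randn(4, 112, 128)  # fake latent for a 1024x896 bucket: (channels, reso[1]//8, reso[0]//8)

# Same keys that save_latents_to_disk() writes.
np.savez(
    "example.npz",
    latents=latent.float().cpu().numpy(),
    original_size=np.array((1152, 896)),  # resized size (width, height) before the centre trim
    crop_left_top=np.array((64, 0)),      # trim offset down to the bucket resolution
)

# Mirror of load_latents_from_disk(): old-format files lack the "latents" key.
npz = np.load("example.npz")
if "latents" not in npz:
    print("old format, please re-generate")
else:
    print(npz["latents"].shape, npz["original_size"].tolist(), npz["crop_left_top"].tolist())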